<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
<title>一起打怪升级呀</title>
<icon>https://www.gravatar.com/avatar/2555127dc0de830d31ceeb98d8565ac8</icon>
<subtitle>别整太大鸭力,多鸡立自己qaq</subtitle>
<link href="/atom.xml" rel="self"/>
<link href="https://blog.nicehuster.cn/"/>
<updated>2024-03-07T07:55:18.888Z</updated>
<id>https://blog.nicehuster.cn/</id>
<author>
<name>nicehuster</name>
<email>[email protected]</email>
</author>
<generator uri="http://hexo.io/">Hexo</generator>
<entry>
<title>Sora技术系列-NaViT</title>
<link href="https://blog.nicehuster.cn/2024/02/26/NaViT/"/>
<id>https://blog.nicehuster.cn/2024/02/26/NaViT/</id>
<published>2024-02-26T11:13:39.000Z</published>
<updated>2024-03-07T07:55:18.888Z</updated>
<content type="html"><![CDATA[<p>在sora的技术报告中提到将视觉数据转换为patches,这个不得不提及ViT这篇开创之作,它便是通过将图像划分为多个patch,然后映射成token序列,输入到transformer完成一些视觉任务,但是,在划分patch时通常需要将图像调整为固定的分辨率进行处理,这种方法在某种程度上是次优的。NaViT这篇论文就是解决不同分辨率输入的问题。Patch n’ Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution,这篇论文是发表在NeuralPS2023上的一篇论文,Google DeepMind的工作,下面详细介绍NaViT工作原理。</p><a id="more"></a><h4 id="NaViT"><a href="#NaViT" class="headerlink" title="NaViT"></a>NaViT</h4><blockquote><p>title:<a href="https://arxiv.org/abs/2307.06304" target="_blank" rel="noopener">Patch n’ Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution</a></p><p>author:Mostafa Dehghani, Basil Mustafa, Josip Djolonga,and etc.</p></blockquote><p>NaViT是一种新的视觉Transformer,通过在训练过程中使用序列打包来处理任意分辨率和宽高比的输入,从而在训练效率、模型适应性和推理灵活性方面超越了传统的Vision Transformer。</p><blockquote><ol><li>Randomly sampling resolutions at training time significantly reduces training cost. 显著降低训练成本</li><li>NaViT results in high performance across a wide range of resolutions, enabling smooth cost-performance trade-off at inference time, and can be adapted with less cost to new tasks. 取得更好的效果,而且以更小的代价适用其他任务上</li></ol></blockquote><p><img src="/img/NaViT.png" alt></p><p>上述过程,展示了NaViT在处理任意分辨率的方法,Data preprocessing阶段,先将不同分辨率图片进行patch处理,再采用了token drop操作随机丢弃一些patch,类似dropout,目的是加速训练;预处理完后,把三张图片生成的patches拉平为一个序列,不够的地方用pad填充;在Self-Attention阶段,使用attention mask技术防止图片之间存在信息交换;</p><h5 id="Architectural-changes"><a href="#Architectural-changes" class="headerlink" title="Architectural changes"></a>Architectural changes</h5><p>NaViT 的架构是建立在 ViT的基础上的,但是又做些修改:</p><blockquote><ol><li>Masked self attention and masked pooling,防止示例相互关注,引入了额外的attention mask;</li><li>Factorized & fractional positional embeddings,支持可变宽高比并很容易外推到没见过的图片分辨率;</li></ol></blockquote><h5 id="Training-changes"><a href="#Training-changes" class="headerlink" title="Training changes"></a>Training changes</h5><p>在训练NaViT时,也引入了一些新的trick:</p><blockquote><ol><li>Continuous Token dropping,packing enables continuous token dropping, whereby the token dropping rate can be varied per-image.</li><li>Resolution sampling. it allows mixed-resolution training by sampling from a distribution of image sizes, while retaining each images’ original aspect ratio.</li></ol></blockquote><h5 id="Improved-training-efficiency-and-performance"><a href="#Improved-training-efficiency-and-performance" class="headerlink" title="Improved training efficiency and performance"></a>Improved training efficiency and performance</h5><p>从实验结果上看,NaViT相比于ViT,训练速度是ViT的四倍,而且性能更好,推理速度也更快;</p><p><img src="/img/NaViT-exp.png" alt></p>]]></content>
<summary type="html">
<p>在sora的技术报告中提到将视觉数据转换为patches,这个不得不提及ViT这篇开创之作,它便是通过将图像划分为多个patch,然后映射成token序列,输入到transformer完成一些视觉任务,但是,在划分patch时通常需要将图像调整为固定的分辨率进行处理,这种方法在某种程度上是次优的。NaViT这篇论文就是解决不同分辨率输入的问题。Patch n’ Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution,这篇论文是发表在NeuralPS2023上的一篇论文,Google DeepMind的工作,下面详细介绍NaViT工作原理。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="Sora" scheme="https://blog.nicehuster.cn/tags/Sora/"/>
</entry>
<entry>
<title>Sora涉及的核心技术</title>
<link href="https://blog.nicehuster.cn/2024/02/23/sora/"/>
<id>https://blog.nicehuster.cn/2024/02/23/sora/</id>
<published>2024-02-23T11:13:39.000Z</published>
<updated>2024-03-07T07:55:11.474Z</updated>
<content type="html"><![CDATA[<p>鉴于OpenAI近期发布的视频生成模型Sora大火,紧跟前沿技术,赶紧学习了一下其官网公开的tech report中Sora涉及的一些技术,这里整理了一些核心的相关论文,接下来在工作之余会抽空学习一下。</p><p><img src="/img/sora.png" alt></p>]]></content>
<summary type="html">
<p>鉴于OpenAI近期发布的视频生成模型Sora大火,紧跟前沿技术,赶紧学习了一下其官网公开的tech report中Sora涉及的一些技术,这里整理了一些核心的相关论文,接下来在工作之余会抽空学习一下。</p>
<p><img src="/img/sora.png" alt
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="Sora" scheme="https://blog.nicehuster.cn/tags/Sora/"/>
</entry>
<entry>
<title>图文多模态理解-BLIP系列</title>
<link href="https://blog.nicehuster.cn/2023/04/27/%E5%9B%BE%E6%96%87%E5%A4%9A%E6%A8%A1%E6%80%81%E7%90%86%E8%A7%A3%E4%B9%8BBLIP%E7%B3%BB%E5%88%97/"/>
<id>https://blog.nicehuster.cn/2023/04/27/图文多模态理解之BLIP系列/</id>
<published>2023-04-27T11:13:39.000Z</published>
<updated>2024-03-07T07:53:39.584Z</updated>
<content type="html"><![CDATA[<p> 这篇文章介绍多模态预训练的一个系列-BLIP,以及针对BLIP改进和延续的一些相关工作。</p><h3 id="BLIP"><a href="#BLIP" class="headerlink" title="BLIP"></a>BLIP</h3><p>这篇也是做的多模态预训练任务,之前的工作都是在于统一vision-language的理解任务,这篇工作同时支持理解和生成任务。这篇文章的作者和ALBEF的作者是同一人。</p><a id="more"></a><h4 id="motivation"><a href="#motivation" class="headerlink" title="motivation"></a>motivation</h4><blockquote><p>模型角度:现有方法分为encoder-based model和encoder-decoder model, 都存在一些问题,前者无法支持下游的生成任务,例如CLIP不支持image-caption;后者比如SimVLM在image-text retreval上效果差。</p><p>数据角度:使用互联网上爬取到的图像文本对,含有很多噪声;</p></blockquote><h4 id="BLIP-1"><a href="#BLIP-1" class="headerlink" title="BLIP"></a>BLIP</h4><p><img src="/img/blip.png" alt></p><p>上图的模型结构中包含了四个部分image encoder, text encoder, image-grounded text encoder, image-grounded text decoder. image encoder, text encoder分别使用的是ViT和Bert提取图像特征和文本特征。image-grounded text encoder引入图像特征做cross attention,用来做图像文本匹配(ITM)任务。 image-grounded text decoder不同于前者,将self attention替换成causal self-attention用于语言模型任务。</p><p>需要注意的是,与text相关的text encoder和Image-grounded Text encoder的共有结构特征是共享的,为了标记差异,在文本的开头分别用”[CLS]”和”[Encoder]”标记。而Image-grounded Text decoder中使用”[Decoder]”。</p><h4 id="Pre-training-Objectives"><a href="#Pre-training-Objectives" class="headerlink" title="Pre-training Objectives"></a>Pre-training Objectives</h4><p>在预训练中有三个目标函数,两个是基于理解的预训练任务以及一个基于生成的预训练任务。计算量比较大的image encoder只需要运算一次。</p><blockquote><ul><li><strong>Image-Text Contrastive Loss(ITC)</strong>:目的都是为了对齐视觉和文本模态的特征;</li><li><strong>Image-Text Matching Loss (ITM)</strong>:判断图像和文本是否匹配,二分类任务;</li><li><strong>Language Modeling Loss (LM)</strong>:不同于MLM任务,这里使用的是NTP;</li></ul></blockquote><h4 id="CapFilt"><a href="#CapFilt" class="headerlink" title="CapFilt"></a>CapFilt</h4><p>在BLIP中为了提升预训练数据的质量,作者设计了CapFlit,在预训练任务中包含生成字幕的预训练任务也有判断图文是否匹配的预训练任务,因此可以让模型生成图片的描述(Captioner),再通过Filter用于判断图像和文本是否匹配。如下图,原始的图文不匹配,在最终预训练时会被过滤掉,而Captioner生成的文本和图片匹配,则在最终预训练时会保留生成的数据。</p><p><img src="/img/capFlit.png" alt></p><p>文中也展示了一些对比的case,从case上看合成的caption比直接从web爬取的caption质量明显好很多。</p><p><img src="/img/CapFilt_exam.png" alt></p><h3 id="BLIP-2"><a href="#BLIP-2" class="headerlink" title="BLIP-2"></a>BLIP-2</h3><p>同样出自Junnan Li, BLIP2新增了一个Querying Transformer (Q-Former),BLIP2训练需要two stage训练。第一个预训练阶段,我们执行vision-language representation learning,强制Q-Former学习与文本最相关的视觉表示。在第二个预训练阶段,我们通过将Q-Former的输出连接到冻结的LLM来执行视觉到语言的生成 学习,并训练Q-Former,使其输出视觉表示可以被LLM解释。如下图所示:</p><p><img src="/img/blip2.png" alt></p><h4 id="motivation-1"><a href="#motivation-1" class="headerlink" title="motivation"></a>motivation</h4><blockquote><ul><li>The cost of vision-and-language pre-training has become increasingly prohibitive due to end-to-end training of large-scale models;</li><li>Pre-trained vision models offer high-quality visual representation.</li><li>Pre-trained language models, in particular large language models (LLMs), offer strong language generation and zero-shot transfer abilities.</li></ul></blockquote><p>一方面,现有的视觉-语言预训练模型越做越大,使得计算成本不断增加;另一方面,预训练的视觉模型/语言模型具有很强的能力,因此作者想到使用frozen的预训练的视觉/语言模型来做视觉-语言对齐的预训练任务。</p><h4 id="method"><a href="#method" class="headerlink" title="method"></a>method</h4><p>训练Q-Former需要两步,分别是vision-language representation learning stage和vision-to-language generative learning stage。</p><ul><li><strong>vision-language representation learning stage</strong></li></ul><p><img src="/img/blip2-stage1.png" alt></p><p>QFormer由Image Transformer和Text Transformer两个子模块构成,它们共享相同自注意力层。QFormer使用= BERTbase 的预训练权重初始化,而交叉注意力层是随机初始化。 Q-Former 总共包含 188M 参数。QFormer拥有 32个query,768维,是远小于ViT-L/14的 257x1024维度的。image 
encoder是冻结的。和BLIP类似,有三个优化函数:</p><blockquote><ul><li><strong>Image-Text Contrastive Loss(ITC)</strong>:目的都是为了对齐image representation和text representation;这里的image representation是输出的query representation(32x768),text representation是text transformer CLS token。计算互信息最大的那个query做梯度反传。为了防止信息泄露,query和text不能互相看见。</li><li><strong>Image-Text Matching Loss (ITM)</strong>:同BLIP,query和text互相都可以看见,做了更细粒度的匹配。</li><li><strong>Image-grounded Text Generation</strong>:训练Q-Former在给定图像情况下,生成文字。这里没有显式输入图像信息,而是与learnable query进行交互,text 可以看到query 和 当前和历史的text,query还是只能看到query。这里强迫了learnable query必须summary图像的抽象信息。</li></ul></blockquote><ul><li><strong>vision-to-language generative learning stage</strong></li></ul><p>在第一个阶段,已经训练得到了一个Q-Former,可以提取图像全局特征和重要的信息。第二个阶段,Q-Former被接入到LLM上,获取生成语言的能力。</p><p><img src="/img/blip2-stage2.png" alt></p><p>首先使用一个FC对齐Q-Former的维度和LLM text embedding维度。然后可以把图像浓缩信息传入到LLM中,因为第一个阶段有<strong>ITG</strong>来监督文本的generation,因此这个图像info天然的可以直接用于LLM,第一阶段已经做了微对齐。</p><h4 id="experiments"><a href="#experiments" class="headerlink" title="experiments"></a>experiments</h4><p>微调带来了第一个好处就是机器成本下降,文中提到”For example, using a single 16-A100(40G) machine, our largest model with ViT-g and FlanT5-XXL requires less than 6 days for the first stage and less than 3 days for the second stage.”。 BLIP2 也展现了强大的zeroshot vision-language 任务。</p><p><img src="/img/blip2-zero-shot-exp.png" alt></p><h3 id="InstructBLIP"><a href="#InstructBLIP" class="headerlink" title="InstructBLIP"></a>InstructBLIP</h3>]]></content>
<summary type="html">
<p> 这篇文章介绍多模态预训练的一个系列-BLIP,以及针对BLIP改进和延续的一些相关工作。</p>
<h3 id="BLIP"><a href="#BLIP" class="headerlink" title="BLIP"></a>BLIP</h3><p>这篇也是做的多模态预训练任务,之前的工作都是在于统一vision-language的理解任务,这篇工作同时支持理解和生成任务。这篇文章的作者和ALBEF的作者是同一人。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="BLIP" scheme="https://blog.nicehuster.cn/tags/BLIP/"/>
</entry>
<entry>
<title>Prompt详解</title>
<link href="https://blog.nicehuster.cn/2023/03/27/prompt/"/>
<id>https://blog.nicehuster.cn/2023/03/27/prompt/</id>
<published>2023-03-27T11:13:39.000Z</published>
<updated>2023-04-02T08:20:04.845Z</updated>
<content type="html"><![CDATA[<p>近年来,在NLP领域热度最高的技术莫过于prompt engineering ,想了解一个方向最快速的方法就是看有关这个方向的survey的paper。本文的内容主要参考CMU刘鹏飞的这篇论文:<a href="https://arxiv.org/pdf/2107.13586.pdf" target="_blank" rel="noopener">Pre-train, Prompt, and Predict: A Systematic Survey of Prompting Methods in Natural Language Processing</a>。这篇survey对于比较清晰的介绍了当前NLP中的范式发展,以及prompt的一些基础知识和prompt的设计方法。</p><p><img src="/img/prompt.jpg" alt></p><a id="more"></a><h4 id="NLP的范式发展"><a href="#NLP的范式发展" class="headerlink" title="NLP的范式发展"></a>NLP的范式发展</h4><p>这篇论文总结总结了在NLP中的四次”范式“的“革命”:</p><p><img src="/img/four-paradigms-in-nlp.png" alt></p><blockquote><ol><li>第一个范式:feature engineering,从原始数据中提取显著特征,并提供具有适当归纳偏差的模型,以便从这个有限的数据中学习;</li><li>第二个范式:architecture engineering,通过设计有利于学习这些特征的合适网络架构来提供归纳偏差</li><li>第三个范式:objective engineering,在大量的原始文本数据上对大模型进行<strong>pretrain</strong>学习通用特征,然后在下游任务上进行<strong>fine-tune</strong>,在这个范式中重点在于设计合适的pretrain和fine-tune的objective function;</li><li>第四个范式:prompt engineering,随着PLM体量不断增大,对其进行fine-tune对硬件、数据、耗时代价要求也在不断上涨。而prompt就是一个更小巧轻量、更普适高效的方法;</li></ol></blockquote><h4 id="prompt是什么"><a href="#prompt是什么" class="headerlink" title="prompt是什么"></a>prompt是什么</h4><p>prompt 说简单也简单,其实就是构建一个语言模版。融入了prompt的新模式大致可以归纳成”pre-train, prompt, and predict“。在该模式中,下游任务被重新调整成类似预训练任务的形式。例如,通常的预训练任务有Masked Language Model, 在文本情感分类任务中,对于 “I love this movie.” 这句输入,可以在后面加上prompt “The movie is <em>_</em>“ 这样的形式,然后让PLM用表示情感的答案填空如 “great”、”fantastic” 等等,最后再将该答案转化成情感分类的标签,这样以来,通过选取合适的prompt,我们可以控制模型预测输出,从而一个完全无监督训练的PLM可以被用来解决各种各样的下游任务。</p><p>下表格是人工设计的prompt模板,其中[x]和[y]可以看作是数据和标签。可以看到,prompt的微小差别,其性能差异可大到20-30点,合适的prompt对于模型的效果至关重要。</p><p><img src="/img/prompt-case-study.png" alt></p><h4 id="prompt数学描述"><a href="#prompt数学描述" class="headerlink" title="prompt数学描述"></a>prompt数学描述</h4><p>对于传统有监督学习任务而言,我们的目标是对x/y进行建模,得到模型$P(y|x,\theta)$,x/y为对应的数据和标签,然而在现实世界中构建大量人工标注的x/y数据,往往费时费力,而且质量也无法保证。而基于prompt的方法则是通过试图学习LM来规避这个问题。LM可表示为$P(X|,\theta)$,是对文本x的直接建模,通过它来直接预测/生成y。这在一定程度上对$P(y|x,\theta)$的直接建模进行了”解耦“,这样也就不再依赖人工标注数据x/y了。对于输入文本$x$,有函数$f_{prompt}(x)$ 将$x$转化为prompt形式$x^{’}$ ,</p><script type="math/tex; mode=display">x^{’} =f_{prompt}(x)</script><p>该函数通常会进行两步操作:</p><blockquote><ol><li>使用一个模板,模板通常为一段自然语言,并且包含有两个空位置:用于填输入x的位置[X]和用于生成答案文本z的位置[Z].</li><li>把输入x填到[X]的位置。</li></ol></blockquote><p>还用前文提到的例子。在文本情感分类的任务中,假设输入是</p><figure class="highlight ini"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="attr">x</span> = <span class="string">" I love this movie."</span></span><br></pre></td></tr></table></figure><p>使用的模板是</p><figure class="highlight css"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">" <span class="selector-attr">[X]</span> <span class="selector-tag">Overall</span>, <span class="selector-tag">it</span> <span class="selector-tag">was</span> <span class="selector-tag">a</span> <span class="selector-attr">[Z]</span> <span class="selector-tag">movie</span>."</span><br></pre></td></tr></table></figure><p>那么得到的x′就应该是 “I love this movie. 
Overall it was a [Z] movie.”</p><p>在实际的研究中,prompts应该有空位置来填充答案,这个位置一般在句中或者句末。如果在句中,一般称这种prompt为<strong>cloze prompt</strong>;如果在句末,一般称这种prompt为<strong>prefix prompt</strong>。[X]和[Z]的位置以及数量都可能对结果造成影响,因此可以根据需要灵活调整。上面讲的都是简单的情感分类任务的 Prompt 设计,读者看到这里自然而然的会想到,其他 NLP 任务的 Prompt 如何设计呢?实际上刘鹏飞大神在他的<a href="https://arxiv.org/abs/2107.13586" target="_blank" rel="noopener">论文</a>中给我们提供了一些参考</p><p><img src="/img/example-prompt-for-different-task.png" alt></p><h4 id="prompt设计"><a href="#prompt设计" class="headerlink" title="prompt设计"></a>prompt设计</h4><p><img src="/img/prompt-engineering.png" alt></p><p>如上图所示,prompt主要从俩方面进行设计:形状以及手工/自动设计。</p><h5 id="prompt-shape"><a href="#prompt-shape" class="headerlink" title="prompt shape"></a>prompt shape</h5><p>prompt的形状主要有两种:cloze prompt,与 Maksed Language Model 的训练方式非常类似,因此对于 MLM 任务来说,Cloze Prompt 更合适;对于生成任务或者使用自回归 LM 解决的任务,Prefix Prompt 更合适;</p><h5 id="Hand-crafted"><a href="#Hand-crafted" class="headerlink" title="Hand-crafted"></a>Hand-crafted</h5><p>Prompt 的模板最开始是人工设计的,人工设计一般基于人类的自然语言知识,力求得到语义流畅且高效的「模板」。人工设计模板的优点是直观,但缺点是需要很多实验、经验以及语言专业知识。下图是 <a href="https://arxiv.org/abs/2103.10385" target="_blank" rel="noopener">GPT Understands, Too</a> 论文中的一个实验结果。</p><p><img src="/img/prompt-case-study.png" alt></p><h5 id="Automated"><a href="#Automated" class="headerlink" title="Automated"></a>Automated</h5><p>为了解决人工设计模板的缺点,许多研究员开始探究如何自动学习到合适的模板。自动学习的模板又可以分为离散(Discrete Prompts)和连续(Continuous Prompts)两大类。离散的prompt是使用具体的words/tokens,而连续的prompt则是基于embeddings来表示prompts。这里主要介绍一下连续的prompt。连续型prompts去掉了两个约束条件:</p><blockquote><ol><li>relax the constraint that the embeddings of template words be the embeddings of natural language words;</li><li>Remove the restriction that the template is parameterized by the pre-trained LM’s parameters;</li></ol></blockquote><p>连续prompts好处是模板中的embedding可以是整个词表的embedding,而不再是有限的一些embedding,此外,模板的参数也不再是直接取PLM的参数,而是由独立的参数,可通过下游任务的数据训练进行调整。目前的连续prompts方法大致可以分为下面几种:</p><h6 id="Prefix-Tuning"><a href="#Prefix-Tuning" class="headerlink" title="Prefix Tuning"></a>Prefix Tuning</h6><p>prefix tuning 最开始由 Li 在<a href="https://arxiv.org/abs/2101.00190" target="_blank" rel="noopener">Prefix-Tuning: Optimizing Continuous Prompts for Generation</a> 这篇论文中提出来的,是一种在输入句子前添加一组连续型向量的方法,该方法保持 PLM 的参数不动,仅训练前缀(Prefix)向量。它的形式化定义是,在给定一个可训练的前缀矩阵$M_{\phi}$和一个固定的参数化为$\theta$的PLM的对数似然目标上进行优化,</p><script type="math/tex; mode=display">maxφ log P (y|x; θ; φ) = maxφ∑yilog P (yi|h<i; θ; φ)</script><p>也是属于一种PEFT方法,我会在下一篇博客中详细介绍这个方法。</p><h6 id="Tuning-Initialized-with-Discrete-Prompts"><a href="#Tuning-Initialized-with-Discrete-Prompts" class="headerlink" title="Tuning Initialized with Discrete Prompts"></a>Tuning Initialized with Discrete Prompts</h6><p>这类方法中连续prompts是用已有的prompts初始化的,已有的prompts可以是手工设计的,也可以是之前搜索发现的离散prompts。Zhong 等人先用一个离散prompt搜索方法定义了一个模板,然后基于该模板初始化虚拟的token,最后微调这些token的embedding以提高准确率。</p><h6 id="Hard-Soft-Prompt-Hybrid-Tuning"><a href="#Hard-Soft-Prompt-Hybrid-Tuning" class="headerlink" title="Hard-Soft Prompt Hybrid Tuning"></a>Hard-Soft Prompt Hybrid Tuning</h6><p>这类方法可以说是手工设计和自动学习的结合,它通常不单纯使用可学习的prompt模板,而是在手工设计的模板中插入一些可学习的embedding。</p><h4 id="最后"><a href="#最后" class="headerlink" title="最后"></a>最后</h4><p>本文只是简单地介绍了一下prompt原理,详细可看原文。此外,如何使用prompt预训练大模型可以参考b站上这个视频,讲解地比较通俗易懂,【2022-陈蕴侬-如何prompt预训练大模型】 <a href="https://www.bilibili.com/video/BV1sh41137Xd/?share_source=copy_web&vd_source=b031a2a5b0a629cd338ab1f2c16ed732" target="_blank" 
rel="noopener">https://www.bilibili.com/video/BV1sh41137Xd/?share_source=copy_web&vd_source=b031a2a5b0a629cd338ab1f2c16ed732</a></p>]]></content>
<summary type="html">
<p>近年来,在NLP领域热度最高的技术莫过于prompt engineering ,想了解一个方向最快速的方法就是看有关这个方向的survey的paper。本文的内容主要参考CMU刘鹏飞的这篇论文:<a href="https://arxiv.org/pdf/2107.13586.pdf" target="_blank" rel="noopener">Pre-train, Prompt, and Predict: A Systematic Survey of Prompting Methods in Natural Language Processing</a>。这篇survey对于比较清晰的介绍了当前NLP中的范式发展,以及prompt的一些基础知识和prompt的设计方法。</p>
<p><img src="/img/prompt.jpg" alt></p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="prompt" scheme="https://blog.nicehuster.cn/tags/prompt/"/>
</entry>
<entry>
<title>参数高效微调方法-LoRA</title>
<link href="https://blog.nicehuster.cn/2023/03/25/LoRA/"/>
<id>https://blog.nicehuster.cn/2023/03/25/LoRA/</id>
<published>2023-03-25T11:13:39.000Z</published>
<updated>2023-03-30T07:37:02.893Z</updated>
<content type="html"><![CDATA[<p>和上一篇文章一样,本文依旧是介绍参数高效微调(Parameter-Efficient Fine-Tuning,PEFT) 方法:LoRA: Low-Rank Adaptation of Large Language Models,这篇论文是发表在ICLR2022上一篇论文,微软的工作,用于解决大模型finetune的问题。下面详细介绍LoRA工作原理。</p><a id="more"></a><h4 id="LoRA"><a href="#LoRA" class="headerlink" title="LoRA"></a>LoRA</h4><blockquote><p>title:<a href="https://arxiv.org/pdf/2106.09685.pdf" target="_blank" rel="noopener">LoRA: Low-Rank Adaptation of Large Language Models</a></p><p>author:Edward J. Hu, Yelong Shen, Phillip Wallis</p><p>code:<a href="https://github.com/microsoft/LoRA" target="_blank" rel="noopener">https://github.com/microsoft/LoRA</a></p></blockquote><p>现有的解决大模型finetune的方法有很多,比如部分fine-tune、adapter以及prompting等。但这些方法大多存在如下问题:</p><blockquote><ol><li>Adapter 引入额外的inference latency(增加了层数);</li><li>prefix-tuning比较难于训练;</li><li>模型性能不如全参数fine-tuning;</li></ol></blockquote><h5 id="Adapter-Layers-Introduce-Inference-Latency"><a href="#Adapter-Layers-Introduce-Inference-Latency" class="headerlink" title="Adapter Layers Introduce Inference Latency"></a>Adapter Layers Introduce Inference Latency</h5><p>显然,增加模型层数会增加inference的时长:</p><blockquote><p>While one can reduce the overall latency by pruning layers or exploiting multi-task settings , there is no direct ways to bypass the extra compute in adapter layers;</p></blockquote><p><img src="/img/infer-latency-peft.png" alt></p><p>从上图可以看出,对于线上batch size为1,sequence length比较短的情况,inference latency的变化比例会更明显。</p><h5 id="Directly-Optimizing-the-Prompt-is-Hard"><a href="#Directly-Optimizing-the-Prompt-is-Hard" class="headerlink" title="Directly Optimizing the Prompt is Hard"></a>Directly Optimizing the Prompt is Hard</h5><p>与Prefix-Tuning的难于训练相比,LoRA则更容易训练:</p><blockquote><p>We observe that prefix tuning is difficult to optimize and that its performance changes non-monotonically in trainable parameters, confirming similar observations in the original paper</p></blockquote><h5 id="模型性能不如Full-fine-tuning"><a href="#模型性能不如Full-fine-tuning" class="headerlink" title="模型性能不如Full fine-tuning"></a>模型性能不如Full fine-tuning</h5><p>预留一些sequence做adaption会让处理下游任务的可用sequence长度变少,一定程度上会影响模型性能:</p><blockquote><p>More fundamentally, reserving a part of the sequence length for adaptation necessarily reduces the sequence length available to process a downstream task, which we suspect makes tuning the prompt less performant compared to other methods.</p></blockquote><h5 id="LoRA-1"><a href="#LoRA-1" class="headerlink" title="LoRA"></a>LoRA</h5><p>先来看下LoRA的motivation:</p><blockquote><p>A neural network contains many dense layers which perform matrix multiplication. The weight matrices in these layers typically have full-rank. When adapting to a specific task, the pre-trained language models have a low “instrisic dimension” and can still learn efficiently despite a random projection to a smaller subspace.</p></blockquote><p>虽然,预训练的大模型有着较多参数,但是应用于下游任务时,其实模型主要依赖low intrinsic dimension,那adaption应该也依赖于此,所以提出了Low-Rank Adaptation (LoRA)。</p><p><img src="/img/lora-arch.png" alt></p><p>如上图所示,LoRA的思想很简单,在原始PLM旁边增加一个旁路,做一个降维再升维的操作,来模拟所谓的 <code>intrinsic rank</code> 。训练的时候固定PLM的参数,只训练降维矩阵A与升维矩阵B。而模型的输入输出维度不变,输出时将BA与PLM的参数叠加。用随机高斯分布初始化A,用0矩阵初始化B,保证开始训练时,此旁路矩阵依然是0矩阵。</p><p>具体来说,假设预训练的参数矩阵为$W_0$,它的更新可表示为:</p><script type="math/tex; mode=display">W_0+\Delta W=W_0+B A \text {, where } B \in \mathbb{R}^{d \times r}, A \in \mathbb{R}^{r \times k} \text {, and the rank } r \ll \min (d, k) \text {. 
}</script><p>有点类似于残差连接,使用旁路的更新来模型full fine-tuning。并且full fine-tuning可以看作时LoRA的特例。</p><p>LoRA与transformer的结合也很简单,在transformer中,在self-attention中有四个权重矩阵$W_q,W_k,W_v,W_o$以及俩个MLP权重。作者仅在self-attention的计算中应用LoRA,而不动MLP模块。对于加在哪个权重参数上,作者做了一系列ablation study,如下表所示:</p><p><img src="/img/diff-type-lora.png" alt></p><p>当部署在生产中时,我们可以显式计算和存储 W = W0 + BA 并像往常一样执行推理,几乎未引入额外的inference latency。</p><p>通过实验也发现,众多数据集上LoRA在只训练极少量参数的前提下,达到了匹配full fine-tuning,是一种高效的参数更新方法。相比Adapter, BitFit,LoRA在较少训练参数时就能保证比较稳定的效果。</p><p><img src="/img/performance-of-lora-in-gpt3.png" alt></p>]]></content>
<summary type="html">
<p>和上一篇文章一样,本文依旧是介绍参数高效微调(Parameter-Efficient Fine-Tuning,PEFT) 方法:LoRA: Low-Rank Adaptation of Large Language Models,这篇论文是发表在ICLR2022上一篇论文,微软的工作,用于解决大模型finetune的问题。下面详细介绍LoRA工作原理。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="PEFT" scheme="https://blog.nicehuster.cn/tags/PEFT/"/>
</entry>
<entry>
<title>参数高效微调方法-Adapter</title>
<link href="https://blog.nicehuster.cn/2023/03/20/peft-adapter/"/>
<id>https://blog.nicehuster.cn/2023/03/20/peft-adapter/</id>
<published>2023-03-20T11:13:39.000Z</published>
<updated>2023-03-30T07:20:59.945Z</updated>
<content type="html"><![CDATA[<p>在NLP任务中,在大模型上进行Fine-tuning是一种有效地迁移学习方法。尤其是,BERT、RoBERTa等模型的提出为NLP的下游任务的解决提供了极大的便利。但在大模型上对下游任务进行fine-tuning时,大模型参数动辄数十亿,存储和训练这种大模型是十分昂贵且耗时的。而且需要庞大计算资源。参数高效微调(Parameter-Efficient Fine-Tuning,PEFT) 方法旨在解决这两个问题,PEFT 方法仅微调少量 (额外) 模型参数,同时冻结预训练大模型 的大部分参数,从而大大降低了计算和存储成本。这也克服了灾难性遗忘 的问题, PEFT 方法在小数据集上也可以取得和全参数fine-tune一样的效果。下文详细介绍一下 参数高效微调方法中的Adapter以及Adapter Fusion。</p><a id="more"></a><h4 id="Adapter"><a href="#Adapter" class="headerlink" title="Adapter"></a>Adapter</h4><blockquote><p>Title: <a href="https://arxiv.org/pdf/1902.00751.pdf" target="_blank" rel="noopener">Parameter-Efficient Transfer Learning for NLP</a></p><p>Author: Neil Houlsby(Google Research).etc</p><p>Code:<a href="https://github.com/google-research/adapter-bert.git" target="_blank" rel="noopener">https://github.com/google-research/adapter-bert.git</a></p></blockquote><p>这篇论文是发表在ICML2019,改论文中提出了Adapter,通过在大模型中插入adapter module,训练时只需训练adapter module,相比于全参数fine-tuning,Adapter只需要训练较少参数即可取得全参数fine-tuning相当的结果。先直观感受一下,Adapter和fine-tune的指标比较,如下图所示:</p><p><img src="/img/adapter-acc-num-curve.png" alt></p><p>下图展示的是adapter module 结构以及插入transformer layer后的结构。</p><p><img src="/img/adapter-arch.png" alt></p><p>在在Adapter内部,它的输入h通过矩阵乘法Wdown,先将特征维度缩小,然后通过一个非线形层$f$,再通过矩阵乘法Wup将特征维度放大到跟adapter输入一样的尺寸,最后通过一个跨层连接,将adapter的输入跟上述结果加到一起作为最终adapter的输出,即下图形式。</p><script type="math/tex; mode=display">\boldsymbol{h} \leftarrow \boldsymbol{h}+f\left(\boldsymbol{h} \boldsymbol{W}_{\text {down }}\right) \boldsymbol{W}_{\text {up }} .</script><p>至于adapter引进的模型参数,假设adapter的输入的特征维度是d,而中间的特征维度是m,那么新增的模型参数有:down-project的参数Wdown 为dxm+m,Wup的参数mxd+d,总共2md+m+d,由于m远小于d,所以真实情况下,一般新增的模型参数都只占语言模型全部参数量的0.5%~8%。同时要注意到,针对下游任务训练需要更新的参数除了adapter引入的模型参数外,还有adapter层后面紧随着的layer normalization层参数需要更新,每个layer normalization层只有均值跟方差需要更新,所以需要更新的参数是2d。(由于插入了具体任务的adapter模块,所以输入的均值跟方差发生了变化,就需要重新训练)</p><p>通过实验,可以发现只训练少量参数的adapter方法的效果可以媲美finetune语言模型全部参数的传统做法。这也验证了adapter是一种高效的参数训练方法,可以快速将语言模型的能力迁移到下游任务中去。同时,可以看到不同数据集上adapter最佳的中间层特征维度m不尽相同。</p><p><img src="/img/adapter-glue-res.png" alt></p><p>为了进一步探究adapter的参数效率跟模型性能的关系,论文做了进一步的实验,同时比对了fine-tune的方式(只更新最后几层的参数或者只更新layer normalization的参数),从结果可以看出adapter是一种更加高效的参数更新方式,同时效果也非常可观,通过引入0.5%~5%的模型参数可以达到不落后先进模型1%的性能。</p><p><img src="/img/adapter-acc-parm-curve.png" alt></p><h4 id="Adapter-Fusion"><a href="#Adapter-Fusion" class="headerlink" title="Adapter Fusion"></a>Adapter Fusion</h4><blockquote><p>Title: <a href="https://arxiv.org/pdf/2005.00247.pdf" target="_blank" rel="noopener">AdapterFusion: Non-Destructive Task Composition for Transfer Learning</a></p><p>Author: Jonas Pfeiffer(UKP Lab).etc</p><p>Code:<a href="https://adapterhub.ml/" target="_blank" rel="noopener">https://adapterhub.ml/</a></p></blockquote><p>这是2020年5月份挂载arxiv上的一篇论文,被EACL 2021接收,这篇论文提出了一种adapter变种,Adapter Fusion,用于融合多任务信息。在了解Adapter Fusion之前,先看一下这个方法提出的任务背景。我们将C 定义为 N 个分类任务的集合,具有不同规模大小的标记数据和不同的损失函数:</p><script type="math/tex; mode=display">C={(D_1,L_1),..,(D_N,L_N)}</script><p>其中,D表示标注数据,L表示损失函数。我们的目的是能够利用上述一组 N 个任务来改进目标任务m,$C_m=(D_m,L_m)$,如下所示,期望先从N个任务中学到一个模型参数(最右边参数),然后利用该参数来学习特定任务m下的一个模型参数(最左边参数):</p><script type="math/tex; mode=display">\Theta_m \leftarrow \underset{\Theta^{\prime}}{\operatorname{argmin}} L_m\left(D_m ; \Theta^{\prime}\right)</script><p>当前主流方法在处理上述问题时,通常有两种方法:</p><blockquote><p>(1)Sequential Fine-Tuning:顺序微调,在每个任务上顺序更新模型的所有权重,在每一步,模型都使用上一步学习的参数进行初始化;</p><p>(2)Multi-Task Learning (MTL):多任务学习,所有任务都是同时训练,学习一个共享表示,使模型能够更好地泛化每个任务;</p></blockquote><p>前者方法容易发生灾难性遗忘问题(catastrophic 
forgetting),后者需要同时访问所有任务数据,不同数据集大小和损失函数各不相同,如何平衡具有较大挑战。为了为了解决上述问题,Adapter Fusion提出一个两阶段的学习策略,其中第一阶段是knowledge extraction stage,在不同任务下引入各自的adapter模块,用于学习特定任务的信息,而第二阶段是knowledge composition step,用于学习聚合多个任务的adapter。</p><p><img src="/img/adapter-fusion-arch.png" alt></p><p>上图展示的是AdapterFusion在transformer中的结构,其中有多个Adapter模块,以及一个Adapter Fusion模块。后者用于融合前者信息。和上一篇论文提出的结构相比,这里去除了multi-head attn后面的Adapter模块。下面详细介绍Adapter Fusion提出的两阶段的学习策略。</p><h5 id="(1)knowledge-extraction-stage"><a href="#(1)knowledge-extraction-stage" class="headerlink" title="(1)knowledge extraction stage"></a>(1)knowledge extraction stage</h5><p>对于该阶段有俩种训练方式,Single-Task Adapters (ST-A),Multi-Task Adapters (MT-A)。</p><p><strong>a. Single-Task Adapters (ST-A)</strong>:对于N个任务,模型都分别独立进行优化,各个任务之间互不干扰,互不影响。对于其中第n个任务而言,相应的目标函数如下所示:</p><script type="math/tex; mode=display">\Phi_n \leftarrow \underset{\Phi}{\operatorname{argmin}} L_n\left(D_n ; \Theta_0, \Phi\right)</script><p>其中$\Phi$ 表示 adapter的权重参数,$\Theta_0$ 表示大模型的预训练参数。</p><p><strong>b. Multi-Task Adapters (MT-A)</strong>:N个任务通过多任务学习的方式,进行联合优化,相应的目标函数如下:</p><script type="math/tex; mode=display">\Theta \leftarrow \underset{\Theta, \Phi}{\operatorname{argmin}}\left(\sum_{n=1}^N L_n\left(D_n ; \Theta_0, \Phi_n\right)\right), \boldsymbol{\Theta}=\Theta_{0 \rightarrow\{1, \ldots, N\}}, \Phi_1, \ldots, \Phi_N</script><h5 id="(2)knowledge-composition-step"><a href="#(2)knowledge-composition-step" class="headerlink" title="(2)knowledge composition step"></a>(2)knowledge composition step</h5><p>对于第二阶段,就是adapter fusion大展身手的时候了。为了避免通过引入特定任务参数而带来的灾难性遗忘问题,adapter fusion提出了一个共享多任务信息的结构。针对特定任务m,adapter fusion联合了第一阶段训练的到的N个adapter信息。固定模型的预训练参数跟N个adapter的参数,新引入adapter fusion的参数,目标函数也是学习针对特定任务m的adapter fusion的参数$\Psi_m$:</p><script type="math/tex; mode=display">\Psi_m \leftarrow \underset{\Psi}{\operatorname{argmin}} L_m\left(D_m ; \Theta, \Phi_1, \ldots, \Phi_N, \Psi\right)</script><p>Adapter fusion的具体结构就是一个attention,它的参数包括query,key, value的矩阵参数,在transformer的每一层都存在,它的query是transformer每个子模块的输出结果,它的key跟value则是N个任务的adapter的输出。通过adapter fusion,模型可以为不同的任务对应的adapter分配不同的权重,聚合N个任务的信息,从而为特定任务输出更合适的结果。</p><p><img src="/img/dapterfusion-arch.png" alt></p><p>通过实验发现,第一阶段采用ST-A+第二阶段Adapter fusion是最有效的方法,在多个数据集上的平均效果达到了最佳。关于MT-A+adapter fusion没有取得最佳的效果,在于第一阶段其实已经联合了多个任务的信息了,所以adapter fusion的作用没有那么明显,同时MT-A这种多任务联合训练的方式需要投入较多的成本,并不算一种高效的参数更新方式。</p><p><img src="/img/dapterfusion-res.png" alt></p>]]></content>
<summary type="html">
<p>在NLP任务中,在大模型上进行Fine-tuning是一种有效地迁移学习方法。尤其是,BERT、RoBERTa等模型的提出为NLP的下游任务的解决提供了极大的便利。但在大模型上对下游任务进行fine-tuning时,大模型参数动辄数十亿,存储和训练这种大模型是十分昂贵且耗时的。而且需要庞大计算资源。参数高效微调(Parameter-Efficient Fine-Tuning,PEFT) 方法旨在解决这两个问题,PEFT 方法仅微调少量 (额外) 模型参数,同时冻结预训练大模型 的大部分参数,从而大大降低了计算和存储成本。这也克服了灾难性遗忘 的问题, PEFT 方法在小数据集上也可以取得和全参数fine-tune一样的效果。下文详细介绍一下 参数高效微调方法中的Adapter以及Adapter Fusion。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="PEFT" scheme="https://blog.nicehuster.cn/tags/PEFT/"/>
</entry>
<entry>
<title>让研究人员绞尽脑汁的Transformer位置编码[转载]</title>
<link href="https://blog.nicehuster.cn/2022/09/23/position%20encoding/"/>
<id>https://blog.nicehuster.cn/2022/09/23/position encoding/</id>
<published>2022-09-23T11:13:39.000Z</published>
<updated>2022-11-07T03:04:52.877Z</updated>
<content type="html"><![CDATA[<p>不同于RNN、CNN等模型,对于Transformer模型来说,位置编码的加入是必不可少的,因为纯粹的Attention模块是无法捕捉输入顺序的,即无法区分不同位置的Token。为此我们大体有两个选择:1、想办法将位置信息融入到输入中,这构成了<strong>绝对位置编码</strong>的一般做法;2、想办法微调一下Attention结构,使得它有能力分辨不同位置的Token,这构成了<strong>相对位置编码</strong>的一般做法。</p><p>虽然说起来主要就是绝对位置编码和相对位置编码两大类,但每一类其实又能衍生出各种各样的变种,为此研究人员可算是煞费苦心、绞尽脑汁了,此外还有一些不按套路出牌的位置编码。本文就让我们来欣赏一下研究人员为了更好地表达位置信息所构建出来的“八仙过海,各显神通”般的编码方案。</p><a id="more"></a><p>详细内容,请移步至原文: <a href="https://kexue.fm/archives/8130" target="_blank" rel="noopener">让研究人员绞尽脑汁的Transformer位置编码</a></p>]]></content>
<summary type="html">
<p>不同于RNN、CNN等模型,对于Transformer模型来说,位置编码的加入是必不可少的,因为纯粹的Attention模块是无法捕捉输入顺序的,即无法区分不同位置的Token。为此我们大体有两个选择:1、想办法将位置信息融入到输入中,这构成了<strong>绝对位置编码</strong>的一般做法;2、想办法微调一下Attention结构,使得它有能力分辨不同位置的Token,这构成了<strong>相对位置编码</strong>的一般做法。</p>
<p>虽然说起来主要就是绝对位置编码和相对位置编码两大类,但每一类其实又能衍生出各种各样的变种,为此研究人员可算是煞费苦心、绞尽脑汁了,此外还有一些不按套路出牌的位置编码。本文就让我们来欣赏一下研究人员为了更好地表达位置信息所构建出来的“八仙过海,各显神通”般的编码方案。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="transformer" scheme="https://blog.nicehuster.cn/tags/transformer/"/>
<category term="位置编码" scheme="https://blog.nicehuster.cn/tags/%E4%BD%8D%E7%BD%AE%E7%BC%96%E7%A0%81/"/>
</entry>
<entry>
<title>Rethinking Positional Encoding In Language Pre-Training</title>
<link href="https://blog.nicehuster.cn/2022/09/18/tupe/"/>
<id>https://blog.nicehuster.cn/2022/09/18/tupe/</id>
<published>2022-09-18T11:13:39.000Z</published>
<updated>2022-11-07T02:48:04.647Z</updated>
<content type="html"><![CDATA[<p>本文介绍一篇有关transformer中位置编码的论文,论文发表自ICLR2021,该论文主要探究了俩个问题:1)在transformer中,处理位置编码时,直接将word embedding和position embedding相加后送入multi-head attention是否合理;2)在序列分类任务中,通常会在句首附加[CLS] token用于捕捉全局信息是否合理;针对改俩个问题对position encoding和[cls]token进行改进和修正。</p><a id="more"></a><h3 id="Motivation"><a href="#Motivation" class="headerlink" title="Motivation"></a>Motivation</h3><h4 id="Position和word关系解耦"><a href="#Position和word关系解耦" class="headerlink" title="Position和word关系解耦"></a>Position和word关系解耦</h4><p>在作者看来,word embedding和position embedding属于俩种异构信息,直接进行相加,会造成mixed correlations。以绝对位置编码为例,计算attention过程如下:</p><script type="math/tex; mode=display">\begin{aligned}\alpha_{i j}^{A b s} &=\frac{\left(\left(w_i+p_i\right) W^{Q, 1}\right)\left(\left(w_j+p_j\right) W^{K, 1}\right)^T}{\sqrt{d}} \\&=\frac{\left(w_i W^{Q, 1}\right)\left(w_j W^{K, 1}\right)^T}{\sqrt{d}}+\frac{\left(w_i W^{Q, 1}\right)\left(p_j W^{K, 1}\right)^T}{\sqrt{d}} \\&+\frac{\left(p_i W^{Q, 1}\right)\left(w_j W^{K, 1}\right)^T}{\sqrt{d}}+\frac{\left(p_i W^{Q, 1}\right)\left(p_j W^{K, 1}\right)^T}{\sqrt{d}}\end{aligned}</script><p>其中w是word embedding,p是position embedding,对attention进行分项,我们可以看到上式中包含4种不同的相关系数:<strong>word-to-word, word-to-position, position-to-word, and position-to-position</strong>。</p><p>第一项和最后一项分别描述了word embedding之间的互相关性和position embedding之间的互相关性;在这里,俩项的权重矩阵$W^Q,W^K$是共享的,在作者看来,俩种不同信息使用相同的projection是不合理;中间俩项描述的是word embedding和position embedding之间的相关性。作者分别对四项的结果进行了可视化,证明其俩种不同信息之间几乎没有任何相关性;</p><p><img src="/img/tupe-correlation.png" alt="image-20221009163220630"></p><p>针对上述问题,作者对attention的计算进行了修正:</p><script type="math/tex; mode=display">\alpha_{i j}=\frac{1}{\sqrt{2 d}}\left(x_i^l W^{Q, l}\right)\left(x_j^l W^{K, l}\right)^T+\frac{1}{\sqrt{2 d}}\left(p_i U^Q\right)\left(p_j U^K\right)^T,</script><p>作者直接移除了中间俩项,并对word embedding和position embedding之间的互相关性使用了不同的权重进行projection,其中$U^Q,U^K$在不同layers之间是权重共享的;</p><h4 id="从positions中解耦-CLS"><a href="#从positions中解耦-CLS" class="headerlink" title="从positions中解耦[CLS]"></a>从positions中解耦[CLS]</h4><p>在BERT及其变体中,[CLS]通常可以存储全局信息,并用于下游任务中。但已有研究表明,这类regular words(与natural words相对)在句子中具有很强的局部依赖性,如果像对待natural words的位置信息一样对待[CLS]的位置信息,则[CLS]很可能会倾向于只关注整个句子的前几个单词,这对于下游任务显然是有害的。针对这个问题,作者对上面$\alpha_{i j}$中的最后一项,position embedding之间的互相关性进行了修改:</p><script type="math/tex; mode=display">\operatorname{reset}_\theta(v, i, j)=\left\{\begin{array}{ll}v_{i j} & i \neq 1, j \neq 1,(\text { not related to }[\mathrm{CLS}]) \\\theta_1 & i=1,(\text { from }[\mathrm{CLS}] \text { to others }) \\\theta_2 & i \neq 1, j=1,(\text { from others to }[\mathrm{CLS}])\end{array},\right.</script><p>其中,$v_{ij}=\frac{1}{\sqrt{2 d}}\left(p_i U^Q\right)\left(p_j U^K\right)^T$,其中,$\theta_1,\theta_2$均为可学习参数,从上可以看出,针对[CLS]与其他位置的相关性计算做了专门处理,如下图所示:</p><p><img src="/img/tupe-cls-untiled.png" alt="image-20221009165034068"></p><p>相当于把[CLS]标识符对应的位置信息进行了抹除,其他位置与[CLS]的相对位置都一致;</p><h3 id="Experiments"><a href="#Experiments" class="headerlink" title="Experiments"></a>Experiments</h3><p><img src="/img/tupe-comparision.png" alt="image-20221009165415253"></p><p>修正以后收敛更快,性能更好,其中TUPE是修正方法的简称,A指绝对位置编码、R指相对位置编码。</p>]]></content>
<summary type="html">
<p>本文介绍一篇有关transformer中位置编码的论文,论文发表自ICLR2021,该论文主要探究了俩个问题:1)在transformer中,处理位置编码时,直接将word embedding和position embedding相加后送入multi-head attention是否合理;2)在序列分类任务中,通常会在句首附加[CLS] token用于捕捉全局信息是否合理;针对改俩个问题对position encoding和[cls]token进行改进和修正。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="transformer" scheme="https://blog.nicehuster.cn/tags/transformer/"/>
<category term="位置编码" scheme="https://blog.nicehuster.cn/tags/%E4%BD%8D%E7%BD%AE%E7%BC%96%E7%A0%81/"/>
</entry>
<entry>
<title>通用多模态预训练方法OFA</title>
<link href="https://blog.nicehuster.cn/2022/09/09/OFA/"/>
<id>https://blog.nicehuster.cn/2022/09/09/OFA/</id>
<published>2022-09-09T11:13:39.000Z</published>
<updated>2022-09-14T10:05:36.584Z</updated>
<content type="html"><![CDATA[<p>OFA是阿里达摩院发表在ICML2022上有关多模态预训练的工作,论文链接:<a href="https://arxiv.org/abs/2202.03052" target="_blank" rel="noopener">OFA: Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework</a>。OFA是通用多模态预训练模型,使用简单的seq2seq的学习框架统一模态(跨模态、视觉、语言等模态)和任务(如图片生成、视觉定位、图片描述、图片分类、文本生成等),OFA通过任务指令即可执行pretrain/finetune,无需引入额外设计任务layer或head,OFA在一系列多模态任务上达到了sota结果。</p><a id="more"></a><p><img src="/img/tasks supported by ofa.png" alt="image-20220914162624793"></p><h3 id="Once-For-All-OFA"><a href="#Once-For-All-OFA" class="headerlink" title="Once For All(OFA)"></a>Once For All(OFA)</h3><p><img src="/img/ofa-framework.png" alt="image-20220914163046310"></p><h4 id="I-O-amp-Architecture"><a href="#I-O-amp-Architecture" class="headerlink" title="I/O & Architecture"></a>I/O & Architecture</h4><p>多模态输入大致分为:图片、文本、位置(坐标)。OFA针对这些不同的输入形式同一成序列形式。对应处理方法如下:</p><blockquote><ul><li>图片输入:使用类似SimVLM和CoAtNet的方式,直接使用ResNet模块将图片卷积成多个patch特征;</li><li>文本输入:使用byte-pair encoding (BPE)进行subword级别分词,然后进行embed;</li><li>位置(坐标):针对目标框类型的输入,把位置编码成token,每个检测框用<x1,y1,x2,y2>四个token表示,参考pix2seq做法;</x1,y1,x2,y2></li><li>统一的词表:语言和视觉token被统一起来,包括subword、图像token和位置token。</li></ul></blockquote><h4 id="Tasks-amp-Modalities"><a href="#Tasks-amp-Modalities" class="headerlink" title="Tasks & Modalities"></a>Tasks & Modalities</h4><p>OFA将单模态/多模态的理解/生成任务统一成seq2seq任务,每个任务通过任务指令进行区分:</p><blockquote><ul><li>visual grounding (VG):“Which region does the text $x_t$ describe?” $x_t$ 目标区域的文本描述;</li><li>grounded captioning(GC):VG的反向任务,“What does the region describe? region: $<x_1,y_1,x_2,y_2>$ ”.</x_1,y_1,x_2,y_2></li><li>image-text matching (ITM):“Does the image describe $x_t$ ?” $x_t$ 是图片的描述;</li><li>image captioning (IC):“What does the image describe?”</li><li>visual question answering (VQA):使用image和question作为输入,学习正确的答案;</li><li>Detection:“What are the objects in the image?”</li><li>image infilling:“What is the image in the middle part?”</li></ul></blockquote><p>在附录中,列举了针对不同任务数据集下的指令,如下表所示:</p><p><img src="/img/ofa-instructions.png" alt="image-20220914171851379"></p><h3 id="Experiments"><a href="#Experiments" class="headerlink" title="Experiments"></a>Experiments</h3><p>OFA在image caption、VQA、visual entailment 和 referring expression comprehension 4个跨模态任务中都取得了SOTA。</p><p><img src="/img/ofa-exp-image-captions.png" alt="image-20220914173932416"></p><p>在image-to-text generation任务中,OFA 也超过了DALLE, CogView和微软的NÜWA模型。</p><p><img src="/img/ofa-exps-text2image.png" alt="image-20220914174359855"></p><h3 id="Conclusion"><a href="#Conclusion" class="headerlink" title="Conclusion"></a>Conclusion</h3><p>OFA在整体工作上并没有提出新的框架,也没有给人一种眼前一亮的感觉,但是就是抵不住效果好,而且将多种模态\任务都统一成seq2seq任务,做的实验非常solid,而且还开源了代码,很赞。</p>]]></content>
<summary type="html">
<p>OFA是阿里达摩院发表在ICML2022上有关多模态预训练的工作,论文链接:<a href="https://arxiv.org/abs/2202.03052" target="_blank" rel="noopener">OFA: Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework</a>。OFA是通用多模态预训练模型,使用简单的seq2seq的学习框架统一模态(跨模态、视觉、语言等模态)和任务(如图片生成、视觉定位、图片描述、图片分类、文本生成等),OFA通过任务指令即可执行pretrain/finetune,无需引入额外设计任务layer或head,OFA在一系列多模态任务上达到了sota结果。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="multi-model" scheme="https://blog.nicehuster.cn/tags/multi-model/"/>
</entry>
<entry>
<title>最简单的多模态VLP模型ViLT</title>
<link href="https://blog.nicehuster.cn/2022/09/03/ViLT/"/>
<id>https://blog.nicehuster.cn/2022/09/03/ViLT/</id>
<published>2022-09-03T11:13:39.000Z</published>
<updated>2022-09-14T10:08:37.809Z</updated>
<content type="html"><![CDATA[<p>本文介绍一篇发表在ICML2021上有关多模态预训练相关的论文,ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision,ViLT不需要使用预训练的ViT来初始化多模态交互的transformer,直接用用交互层来处理视觉特征,无需额外增加视觉encoder,比如提取region features的object detector。</p><h4 id="Taxonomy-of-Vision-and-Language-Models"><a href="#Taxonomy-of-Vision-and-Language-Models" class="headerlink" title="Taxonomy of Vision-and-Language Models"></a>Taxonomy of Vision-and-Language Models</h4><p>在这篇论文中,作者依据:1)在参数或者计算上,俩种模态是否保持平衡;2)俩种模态是否是在深度网络中进行交互;将现有的VLP模型大致划分为以下四种:</p><p><img src="/img/four-cat-vlp.png" alt="image-20220905161210413"></p><p>其中每个圆角矩形的大小表示相对计算量大小,VE、TE和MI分别是visual embedding、text embedding和modality interaction的简写。</p><a id="more"></a><h4 id="Visual-Embedding-Schema"><a href="#Visual-Embedding-Schema" class="headerlink" title="Visual Embedding Schema"></a>Visual Embedding Schema</h4><p><img src="/img/Visual Embedding Schema.png" alt="image-20220905161656563"></p><p>现有的VLP模型的text embedding大多是使用类BERT结构,但是visual embedding存在着差异。在大多数情况下,visual embedding是现有VLP模型的瓶颈。visual embedding的方法总共有三大类,其中region feature方法通常采用Faster R-CNN二阶段检测器提取region的特征,grid feature方法直接使用CNN提取grid的特征,patch projection方法采用类似ViT将输入图片切片投影提取特征。</p><h4 id="Vision-and-Language-Transformer"><a href="#Vision-and-Language-Transformer" class="headerlink" title="Vision-and-Language Transformer"></a>Vision-and-Language Transformer</h4><p><img src="/img/vilt-overview.png" alt="image-20220905162017113"></p><p>上图展示的是ViLT结构,采用的是single-stream结构,对visual和text进行concat后进行交互操作。在文本特征提取部分采用的是word embedding得到text embedding而没有使用类似bert结构,在图像特征提取部分采用的是ViT那套对图像切块然后拼成序列通过线性映射得到visual embedding,最后俩个embedding都会结合各自的position embedding和modal-type embedding输入到transformer encoder中进行交互;其中modal-type embedding用于区分text embedding和visual embedding;此外,需要注意的是,text embedding和visual embedding分别都嵌入了额外的[cls] embedding用于后续的下游任务;</p><h4 id="Pre-training-Objectives"><a href="#Pre-training-Objectives" class="headerlink" title="Pre-training Objectives"></a>Pre-training Objectives</h4><p>ViLT的预训练目标函数包括俩个:image text matching (ITM) and masked language modeling (MLM);</p><blockquote><p><strong>ImageText Matching</strong>:随机以0.5的概率将文本对应的图片替换成不同的图片,然后对文本标志位对应输出使用一个线性的ITM head将输出feature映射成一个二值logits,用来判断图像文本是否匹配。</p><p><strong>Masked Language Modeling</strong>:MLM的目标是通过文本的上下文信息去预测masked的文本tokens。随机以0.15的概率mask掉tokens;</p></blockquote><p>此外,在ViLT中还使用了whole word masking技巧;将连续的子词tokens进行mask的技巧,避免了只通过单词上下文进行预测。比如,使用bert-base-uncased预训练的tokenizer对“giraffe”进行分词,就会得到[“gi”, “##raf”, “##fe”],可以mask成[“gi”, “[MASK]”, “##fe”],模型会通过mask的上下文信息[“gi”,“##fe”]来预测mask的“##raf”,就会导致不利用图像信息。</p><h4 id="Experiments"><a href="#Experiments" class="headerlink" title="Experiments"></a>Experiments</h4><p><img src="/img/vilt-runtime.png" alt="image-20220905163545451"></p><p>如上图所示,相比于基于region feature方法和grid feature方法,在速度上分别快60倍和4倍。在下游任务上也取得不错的性能。</p><p><img src="/img/vilt-cls.png" alt="image-20220905164010026"></p><p><img src="/img/vilt-retrival.png" alt="image-20220905163936444"></p>]]></content>
<summary type="html">
<p>本文介绍一篇发表在ICML2021上有关多模态预训练相关的论文,ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision,ViLT不需要使用预训练的ViT来初始化多模态交互的transformer,直接用用交互层来处理视觉特征,无需额外增加视觉encoder,比如提取region features的object detector。</p>
<h4 id="Taxonomy-of-Vision-and-Language-Models"><a href="#Taxonomy-of-Vision-and-Language-Models" class="headerlink" title="Taxonomy of Vision-and-Language Models"></a>Taxonomy of Vision-and-Language Models</h4><p>在这篇论文中,作者依据:1)在参数或者计算上,俩种模态是否保持平衡;2)俩种模态是否是在深度网络中进行交互;将现有的VLP模型大致划分为以下四种:</p>
<p><img src="/img/four-cat-vlp.png" alt="image-20220905161210413"></p>
<p>其中每个圆角矩形的大小表示相对计算量大小,VE、TE和MI分别是visual embedding、text embedding和modality interaction的简写。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="multi-model" scheme="https://blog.nicehuster.cn/tags/multi-model/"/>
</entry>
<entry>
<title>Oscar&METER方法详解</title>
<link href="https://blog.nicehuster.cn/2022/08/29/Oscar_METER/"/>
<id>https://blog.nicehuster.cn/2022/08/29/Oscar_METER/</id>
<published>2022-08-29T11:13:39.000Z</published>
<updated>2022-09-09T08:48:21.921Z</updated>
<content type="html"><![CDATA[<p>本文要介绍的是微软的俩篇有关VLP的工作,Oscar和METER,前者是发表在CVPR2020,后者是发表在CVPR2022。论文链接如下:<a href="https://arxiv.org/pdf/2004.06165.pdf" target="_blank" rel="noopener">Oscar: Object-Semantics Aligned Pre-training for Vision-Language Tasks</a>,<a href="https://arxiv.org/pdf/2111.02387.pdf" target="_blank" rel="noopener">An Empirical Study of Training End-to-End Vision-and-Language Transformers</a>,下面大致介绍这俩篇工作的具体内容。</p><a id="more"></a><h3 id="Oscar"><a href="#Oscar" class="headerlink" title="Oscar"></a>Oscar</h3><p>这篇论文中提出了一种新的多模态预训练方法Oscar,把object用作视觉和语言语义层面上的Anchor Point,以简化图像和文本之间的语义对齐的学习任务,在多个下游任务上刷新了SOTA。</p><p><img src="/img/oscar.png" alt="oscar"></p><h4 id="Motivation"><a href="#Motivation" class="headerlink" title="Motivation"></a>Motivation</h4><p>在此之前,VLP方法都是简单粗暴地将图像区域特征和文本特征连接起来作为模型的输入以进行预训练,并不为模型提供任何线索,希望模型能利用Transformer的自我注意机制,使用蛮力来学习图像文本语义对齐方式。检测器在图像上检测的object通常会出现在对应caption text中,因此作者提出使用检测出来的物体标签对应caption中的词建立一个关联。</p><h4 id="Pipeline"><a href="#Pipeline" class="headerlink" title="Pipeline"></a>Pipeline</h4><p><img src="/img/oscar-pipeline.png" alt="oscar-pipeline"></p><p>上图展示了OSCAR的pipeline,通过将对象标签作为anchor引入,Oscar在两个方面与现有的VLP不同:</p><blockquote><ol><li>输入表示:每个image-text样本定义为一个三元组(单词序列,物体标签,区域特征)。</li><li>目标函数:作者从两个不同的角度设计目标函数: modality视角(Contrastive Loss)和dictionary视角(Masked Token Loss)。</li></ol></blockquote><p>注意,在这里object tag输入的embedding是使用同一词表得到word embedding。</p><h4 id="Experiments"><a href="#Experiments" class="headerlink" title="Experiments"></a>Experiments</h4><p><img src="/img/oscar-exp.png" alt="oscar-exp"></p><p>Oscar在六项任务上均达到了SOTA。在大多数任务上,Osacr的base model要优于以前方法的large model,其表明Oscar具有很高的参数利用效率,这是因为<strong>object tag的使用大大简化了图像和文本之间语义对齐的学习</strong>。</p><p>在这之后,原班作者在Oscar基础上针对检测模型部分提出了<a href="https://arxiv.org/abs/2101.00529" target="_blank" rel="noopener">VinVL</a>,聚焦于提升检测模型提取视觉语义特征能力。</p><h3 id="METER"><a href="#METER" class="headerlink" title="METER"></a>METER</h3><p>这是微软在CVPR2022上的有关VLP的工作。本文提出了METER,一个end2end的VLP框架,并从visual encoder、text encoder、Multimodal Fusion、结构设计以及与预训练目标函数上对VLP做了详细实验分析。METER在VQAv2上取得sota结果。</p><p><img src="/img/meter-overview.png" alt="meter-overview"></p><h4 id="Glossary-of-VLP-Models"><a href="#Glossary-of-VLP-Models" class="headerlink" title="Glossary of VLP Models"></a>Glossary of VLP Models</h4><p>以往基于OD(object detector)的VLP方法需要freeze object detector,限制了VLP模型能力,而且提取region特征时间代价较大。近期大多使用的ViT做visual encoder的VLP方法相比VinVL(OD)性能存在差距,为缩小差距本文提出METER,探索如何设计VLP模型。作者对现有VLP工作根据visual encoder、text encoder、Multimodal Fusion、Decoder以及预训练目标函数进行分类划分。</p><p><img src="/img/meter-glossary-of-vlp.png" alt="meter-glossary-of-vlp"></p><h4 id="METER-1"><a href="#METER-1" class="headerlink" title="METER"></a>METER</h4><p>总体结构:输入图片和文本对,图片经过visual encoder(CLIP-ViT,Swin,BEiT等)编码,文本经过text encoder(BERT,RoBERTa,DeBERTA等)编码,之后两者经过多模态融合模块(Merged Attention/Co-Attention)进行模态信息交互,最后经过一个可选的解码器输出结果。在论文中,作者对VLP模型各个模块都进行了分类和阐述,在实验部分进行了综合性分析并得到了每个部分最好的一个结构,从和产生METER最终结构。</p><blockquote><p>Vsual Encoder: CLIP-ViT-224/16,Swin Transformer</p><p>Text Encoder:Roberta</p><p>Multimodal Fusion: Co-attention(如下图所示)</p><p>Pre-training Objectives:MLM(masked language modeling )+ITM(image-text matching)</p></blockquote><p><img src="/img/meter-multimodel-fusion.png" alt="meter-multimodel-fusion"></p><p>上图展示的是多模态融合模块的结构,在METER中,使用的是Co-Attention,Co-Attention中包含了堆叠的6层transformer layers,每层包含了一个self-attention,一个co-attention和一个前馈网络。没有decoder和编码器参数共享。</p><h4 id="Experiments-1"><a href="#Experiments-1" class="headerlink" 
title="Experiments"></a>Experiments</h4><p>在预训练数据方面,作者仅使用了COCO, Conceptual Captions, SBU Captions and Visual Genome,总共4M图片数据。在多个下游任务上进行验证,其中包括VQAv2,visual reasoning(NLVR2), visual entailment(SNLI-VE)和image-text retrieval(COO, Flickr30k)。作者实验做的非常详细,推荐去看原文。</p><h4 id="Conclusion"><a href="#Conclusion" class="headerlink" title="Conclusion"></a>Conclusion</h4><p>(1)MERTER是Vision transformer + Text transformer结构</p><p>(2)在METER中Vision encoder用CLIP-ViT或者Swim transformer,Text encoder用Roberta,多模态融合用co-attention;</p><p>(3)在目标函数上,MLM+ITM都对VLP模型有帮助,但是MIM会带来负面影响;</p>]]></content>
<summary type="html">
<p>本文要介绍的是微软的俩篇有关VLP的工作,Oscar和METER,前者是发表在CVPR2020,后者是发表在CVPR2022。论文链接如下:<a href="https://arxiv.org/pdf/2004.06165.pdf" target="_blank" rel="noopener">Oscar: Object-Semantics Aligned Pre-training for Vision-Language Tasks</a>,<a href="https://arxiv.org/pdf/2111.02387.pdf" target="_blank" rel="noopener">An Empirical Study of Training End-to-End Vision-and-Language Transformers</a>,下面大致介绍这俩篇工作的具体内容。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="multi-model" scheme="https://blog.nicehuster.cn/tags/multi-model/"/>
</entry>
<entry>
<title>ALBEF方法详解</title>
<link href="https://blog.nicehuster.cn/2022/08/18/ALBEF/"/>
<id>https://blog.nicehuster.cn/2022/08/18/ALBEF/</id>
<published>2022-08-18T11:13:39.000Z</published>
<updated>2022-09-09T08:48:14.812Z</updated>
<content type="html"><![CDATA[<p>这篇文章介绍一篇多模态预训练相关的论文,<a href="https://arxiv.org/abs/2107.07651" target="_blank" rel="noopener">Align before Fuse: Vision and Language Representation Learning with Momentum Distillation</a>,单位是Salesforce Research,下面大致的介绍一下两篇论文的具体工作。这篇paper提出了一个新的视觉-语言表征学习框架,通过在融合之前首先对齐单模态表征来实现最佳性能。</p><a id="more"></a><p>现有的VLP方法存在如下三个限制:</p><blockquote><p>Limitation 1:以CLIP和ALIGN为代表的方法分别独立学习单模态的图像encoder和文本encoder,缺乏对图像和文本之间的复杂互动进行建模的能力,因此它们不擅长于需要细粒度图像-文本理解的任务;</p><p>Limitation 2:以UNITER为代表的方法使用多模态编码器联合学习图像与文本,然而,从区域中提取的图片特征和文本词向量是没有对齐的;</p><p>Limitation 3:现有用于预训练的数据集大多是由从网络上收集的嘈杂的图像-文本对组成。广泛使用的预训练目标,如掩码语言建模(MLM),容易对噪声文本过度拟合,这将损害表示学习。</p></blockquote><p>为了解决这些限制,我们提出了ALign BEfore Fuse(ALBEF),ALBEF在多个视觉-语言下游任务上取得了SOTA的性能,如图像-文本检索、视觉问题回答(VQA)和自然语言视觉推理(NLVR)。</p><h4 id="整体结构"><a href="#整体结构" class="headerlink" title="整体结构"></a>整体结构</h4><p><img src="/img/ALBEF.png" alt="ALBEF"></p><p>上图展示了ALBEF的整体框架结构,ALBEF包含一个image encoder(ViT-B/16),一个text encoder(BERT的前6层),以及一个multimodal encoder(BERT的后6层与额外的交叉注意力层)。我们通过共同优化以下三个目标来对ALBEF进行预训练:</p><blockquote><p>Objective 1:图像-文本对比学习应用于单模态的image encoder和text encoder。它使图像特征和文本特征相一致,同时训练单模态编码器更好地理解图像和文本的语义;</p><p>Objective 2:图像-文本匹配应用于多模态编码器,预测一对图像和文本是否匹配。我们还使用了难样本挖掘,选择具有较高相似度的样本进行学习;</p><p>Objective 3:在多模态编码器上应用掩码语言建模(MLM)进行训练;</p></blockquote><h4 id="Momentum-Distillation"><a href="#Momentum-Distillation" class="headerlink" title="Momentum Distillation"></a>Momentum Distillation</h4><p>从网络上收集的图像-文本对往往是弱相关的:文本可能包含与图像无关的词,或者图像可能包含文本中没有描述的实体。为了从嘈杂的数据中学习,我们提出了动量蒸馏法,即使用动量模型为图像-文本对比学习和掩码语言建模生成伪目标。</p><h4 id="下游任务上的应用"><a href="#下游任务上的应用" class="headerlink" title="下游任务上的应用"></a>下游任务上的应用</h4><p>ALBEF在多个下游任务上取得了最先进的性能,如下表所示。在图像-文本检索方面,ALBEF优于在更大数量级的数据集上进行预训练的方法(CLIP[2]和ALIGN[3])。在VQA、NLVR和VE方面,ALBEF优于那些使用预先训练的物体检测器、额外的物体标签或对抗性数据增强的方法。</p><p><img src="/img/ALBEF-zero-shot-i2tr.png" alt="ALBEF-zero-shot-i2tr"></p><p><img src="/img/ALBEF-VQA.png" alt="ALBEF-VQA"></p><h4 id="Visual-Grounding"><a href="#Visual-Grounding" class="headerlink" title="Visual Grounding"></a>Visual Grounding</h4><p>有意思的是,ALBEF还隐含的学习了物体、属性和关系。使用Grad-CAM对multimodal encoder的交叉注意力进行可视化,在弱监督的visual grounding任务上取得很不错的结果,如下示例:</p><p><img src="/img/ALBEF-vg-vis1.png" alt></p><p><img src="/img/ALBEF-vg-vis2.png" alt></p><p><img src="/img/ALBEF-vg-vis3.png" alt></p>]]></content>
<summary type="html">
<p>这篇文章介绍一篇多模态预训练相关的论文,<a href="https://arxiv.org/abs/2107.07651" target="_blank" rel="noopener">Align before Fuse: Vision and Language Representation Learning with Momentum Distillation</a>,单位是Salesforce Research,下面大致的介绍一下两篇论文的具体工作。这篇paper提出了一个新的视觉-语言表征学习框架,通过在融合之前首先对齐单模态表征来实现最佳性能。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="multi-model" scheme="https://blog.nicehuster.cn/tags/multi-model/"/>
</entry>
<entry>
<title>BERT原理详解与HuggingFace使用[转载]</title>
<link href="https://blog.nicehuster.cn/2022/08/04/BERT/"/>
<id>https://blog.nicehuster.cn/2022/08/04/BERT/</id>
<published>2022-08-04T11:13:39.000Z</published>
<updated>2022-08-25T07:35:29.745Z</updated>
<content type="html"><![CDATA[<p>最近在做一些图文理解相关的工作,顺带了解了一下BERT,自BERT(Bidirectional Encoder Representations from Transformer)出现后,NLP界开启了一个全新的范式。本文主要介绍BERT的原理,以及如何使用HuggingFace提供的 <code>transformers</code> 库完成基于BERT的微调任务。</p><a id="more"></a><h4 id="预训练"><a href="#预训练" class="headerlink" title="预训练"></a>预训练</h4><p>BERT在一个较大的语料上进行预训练(Pre-train)。预训练主要是在数据和算力充足的条件下,训练一个大模型,在其他任务上可以利用预训练好的模型进行微调(Fine-tune)。</p><h4 id="训练目标"><a href="#训练目标" class="headerlink" title="训练目标"></a>训练目标</h4><p>BERT使用了维基百科等语料库数据,共几十GB,这是一个庞大的语料库。对于一个GB级的语料库,雇佣人力进行标注成本极高。BERT使用了两个巧妙方法来无监督地训练模型:<strong>Masked Language Modeling</strong>和<strong>Next Sentence Prediction</strong>。这两个方法可以无需花费时间和人力标注数据,以较低成本无监督地得到训练数据。图1就是一个输入输出样例。</p><p>对于Masked Language Modeling,给定一些输入句子(图1中最下面的输入层),BERT将输入句子中的一些单词盖住(图1中Masked层),经过中间的词向量和BERT层后,BERT的目标是让模型能够预测那些刚刚被盖住的词。还记得英语考试中,我们经常遇到“完形填空”题型吗?能把完形填空做对,说明已经理解了文章背后的语言逻辑。BERT的Masked Language Modeling本质上就是在做“完形填空”:预训练时,先将一部分词随机地盖住,经过模型的拟合,如果能够很好地预测那些盖住的词,模型就学到了文本的内在逻辑。</p><p><img src="http://aixingqiu-1258949597.cos.ap-beijing.myqcloud.com/2021-12-18-pretrain.png" alt="img">图1 BERT预训练的输入和输出</p><p>除了“完形填空”,BERT还需要做Next Sentence Prediction任务:预测句子B是否为句子A的下一句。Next Sentence Prediction有点像英语考试中的“段落排序”题,只不过简化到只考虑两句话。如果模型无法正确地基于当前句子预测Next Sentence,而是生硬地把两个不相关的句子拼到一起,两个句子在语义上是毫不相关的,说明模型没有读懂文本背后的意思。</p><h4 id="词向量"><a href="#词向量" class="headerlink" title="词向量"></a>词向量</h4><p>在基于深度学习的NLP方法中,文本中的词通常都用一维向量来表示。某两个词向量的 Cosine 距离较小,说明两个词在语义上相似。</p><p>信息</p><p>词向量一般由Token转换而成。英文中,一个句子中的词由空格、句号等标点隔开,我们很容易从句子中获得词。英文的词通常有前缀、后缀、词根等,在获得英文的词后,还需要抽出词根,比如图1所展示的,将“playing”切分为“play”和“##ing”。如果不对英文词进行类似词根抽取,词表过大,不容易拟合。对于英文,“play”和“##ing”分别对应两个Token。</p><p>中文一般由多个字组成一个词,传统的中文文本任务通常使用一些分词工具,得到严格意义上的词。在原始的BERT中,对于中文,并没有使用分词工具,而是直接以字为粒度得到词向量的。所以,原始的中文BERT(bert-base-chinese)输入到BERT模型的是字向量,Token就是字。后续有专门的研究去探讨,是否应该对中文进行必要的分词,以词的形式进行切分,得到向量放入BERT模型。</p><p>为了方面说明,本文不明确区分字向量还是词向量,都统称为词向量。</p><p>我们首先需要将文本中每个Token都转换成一维词向量。假如词向量的维度为<code>hidden_size</code>,句子的Token长度为<code>seq_len</code>,或者说句子共包含<code>seq_len</code>个Token,那么上图中,输入就是<code>seq_len * hidden_size</code>。再加上<code>batch_size</code>,那么输入就是<code>batch_size * seq_len * hidden_size</code>。上图只展示了一个样本,未体现出<code>batch_size</code>,或者可以理解成<code>batch_size = 1</code>,即每次只处理一条文本。</p><p>词向量经过BERT模型一系列复杂的转换后,模型最后仍然以词向量的形式输出,用以对文本进行语义表示。输入的词向量是<code>seq_len * hidden_size</code>,句子共<code>seq_len</code>个Token,将每个Token都转换成词向量,送入BERT模型。经过BERT模型后,得到的输出仍然是<code>seq_len * hidden_size</code>维度。输出仍然是<code>seq_len</code>的长度,其中输出的<code>i</code> 个位置(0 < <code>i</code> < <code>seq_len</code>)的词向量,表示经过了拟合后的第<code>i</code>个Token的语义表示。后续可以用输出中每个位置的词向量来进行一些其他任务,比如命名实体识别等。</p><p>除了使用Masked方法故意盖住一些词外,BERT还加了一些特殊的符号:<code>[CLS]</code>和<code>[SEP]</code>。<code>[CLS]</code>用在句首,是句子序列中<code>i = 0</code>位置的Token。BERT认为输出序列的<code>i = 0</code>位置的Token对应的词向量包含了整个句子的信息,可对整个句子进行分类。<code>[SEP]</code>用在分割前后两个句子上。</p><h4 id="微调"><a href="#微调" class="headerlink" title="微调"></a>微调</h4><p>经过预训练后,得到的模型可以用来微调各类任务。</p><ul><li>单文本分类任务。刚才提到,BERT模型在文本前插入一个<code>[CLS]</code>符号,并将该符号对应的输出向量作为整篇文本的语义表示,用于文本分类,如图2所示。对于<code>[CLS]</code>符号,可以理解为:与文本中已有的其它字/词相比,这个无明显语义信息的符号会更“公平”地融合文本中各个字/词的语义信息。</li></ul><p><img src="http://aixingqiu-1258949597.cos.ap-beijing.myqcloud.com/2021-12-18-single-classification.jpeg" alt="img">图2 单文本分类</p><ul><li>语句对分类任务。语句对分类任务的实际应用场景包括:问答(判断一个问题与一个答案是否匹配)、语句匹配(两句话是否表达同一个意思)等。对于该任务,BERT模型除了添加<code>[CLS]</code>符号并将对应的输出作为文本的语义表示,输入两句话之间用<code>[SEP]</code>符号作分割。</li></ul><p><img src="http://aixingqiu-1258949597.cos.ap-beijing.myqcloud.com/2021-12-18-pair-classification.jpeg" alt="img">图3 
语句对分类</p><ul><li>序列标注任务。序列标注任务的实际应用场景包括:命名实体识别、中文分词、新词发现(标注每个字是词的首字、中间字或末字)、答案抽取(答案的起止位置)等。对于该任务,BERT模型利用文本中每个Token对应的输出向量对该Token进行标注(分类),如下图所示(B(Begin)、I(Inside)、E(End)分别表示一个词的第一个字、中间字和最后一个字)。</li></ul><p><img src="http://aixingqiu-1258949597.cos.ap-beijing.myqcloud.com/2021-12-18-seq-tagging.jpeg" alt="img">图4 序列标注</p><h4 id="模型结构"><a href="#模型结构" class="headerlink" title="模型结构"></a>模型结构</h4><p>Transformer是BERT的核心模块,Attention注意力机制又是Transformer中最关键的部分。BERT用到的主要是Transformer的Encoder,没有使用Transformer Decoder。把多个Transformer Encoder组装起来,就构成了BERT。在论文中,作者分别用12个和24个Transformer Encoder组装了两套BERT模型,两套模型的参数总数分别为110M和340M。</p><p><img src="http://aixingqiu-1258949597.cos.ap-beijing.myqcloud.com/2021-12-18-transformer-encoder.jpeg" alt="img"></p><h4 id="HuggingFace-Transformers"><a href="#HuggingFace-Transformers" class="headerlink" title="HuggingFace Transformers"></a>HuggingFace Transformers</h4><p>使用BERT和其他各类Transformer模型,绕不开<a href="https://huggingface.co/" target="_blank" rel="noopener">HuggingFace</a>提供的Transformers生态。HuggingFace提供了各类BERT的API(<code>transformers</code>库)、训练好的模型(HuggingFace Hub)还有数据集(<code>datasets</code>)。最初,HuggingFace用PyTorch实现了BERT,并提供了预训练的模型,后来。越来越多的人直接使用HuggingFace提供好的模型进行微调,将自己的模型共享到HuggingFace社区。HuggingFace的社区越来越庞大,不仅覆盖了PyTorch版,还提供TensorFlow版,主流的预训练模型都会提交到HuggingFace社区,供其他人使用。</p><p>使用<code>transformers</code>库进行微调,主要包括:</p><ul><li>Tokenizer:使用提供好的Tokenizer对原始文本处理,得到Token序列;</li><li>构建模型:在提供好的模型结构上,增加下游任务所需预测接口,构建所需模型;</li><li>微调:将Token序列送入构建的模型,进行训练。</li></ul><h4 id="Tokenizer"><a href="#Tokenizer" class="headerlink" title="Tokenizer"></a>Tokenizer</h4><p>下面两行代码会创建 <code>BertTokenizer</code>,并将所需的词表加载进来。首次使用这个模型时,<code>transformers</code> 会帮我们将模型从HuggingFace Hub下载到本地。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span><span class="keyword">from</span> transformers <span class="keyword">import</span> BertTokenizer</span><br><span class="line"><span class="meta">>>> </span>tokenizer = BertTokenizer.from_pretrained(<span class="string">'bert-base-cased'</span>)</span><br></pre></td></tr></table></figure><p>用得到的<code>tokenizer</code>进行分词:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>encoded_input = tokenizer(<span class="string">"我是一句话"</span>)</span><br><span class="line"><span class="meta">>>> </span>print(encoded_input)</span><br><span class="line">{<span class="string">'input_ids'</span>: [<span class="number">101</span>, <span class="number">2769</span>, <span class="number">3221</span>, <span class="number">671</span>, <span class="number">1368</span>, <span class="number">6413</span>, <span class="number">102</span>], </span><br><span class="line"><span class="string">'token_type_ids'</span>: [<span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>], </span><br><span class="line"><span class="string">'attention_mask'</span>: [<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span 
class="number">1</span>, <span class="number">1</span>]}</span><br></pre></td></tr></table></figure><p>得到的一个Python <code>dict</code>。其中,<code>input_ids</code>最容易理解,它表示的是句子中的每个Token在词表中的索引数字。词表(Vocabulary)是一个Token到索引数字的映射。可以使用<code>decode()</code>方法,将索引数字转换为Token。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>tokenizer.decode(encoded_input[<span class="string">"input_ids"</span>])</span><br><span class="line"><span class="string">'[CLS] 我 是 一 句 话 [SEP]'</span></span><br></pre></td></tr></table></figure><p>可以看到,<code>BertTokenizer</code>在给原始文本处理时,自动给文本加上了<code>[CLS]</code>和<code>[SEP]</code>这两个符号,分别对应在词表中的索引数字为101和102。<code>decode()</code>之后,也将这两个符号反向解析出来了。</p><p><code>token_type_ids</code>主要用于句子对,比如下面的例子,两个句子通过<code>[SEP]</code>分割,0表示Token对应的<code>input_ids</code>属于第一个句子,1表示Token对应的<code>input_ids</code>属于第二个句子。不是所有的模型和场景都用得上<code>token_type_ids</code>。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>encoded_input = tokenizer(<span class="string">"您贵姓?"</span>, <span class="string">"免贵姓李"</span>)</span><br><span class="line"><span class="meta">>>> </span>print(encoded_input)</span><br><span class="line">{<span class="string">'input_ids'</span>: [<span class="number">101</span>, <span class="number">2644</span>, <span class="number">6586</span>, <span class="number">1998</span>, <span class="number">136</span>, <span class="number">102</span>, <span class="number">1048</span>, <span class="number">6586</span>, <span class="number">1998</span>, <span class="number">3330</span>, <span class="number">102</span>], </span><br><span class="line"><span class="string">'token_type_ids'</span>: [<span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>], </span><br><span class="line"><span class="string">'attention_mask'</span>: [<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>]}</span><br></pre></td></tr></table></figure><p>句子通常是变长的,多个句子组成一个Batch时,<code>attention_mask</code>就起了至关重要的作用。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>batch_sentences = 
[<span class="string">"我是一句话"</span>, <span class="string">"我是另一句话"</span>, <span class="string">"我是最后一句话"</span>]</span><br><span class="line"><span class="meta">>>> </span>batch = tokenizer(batch_sentences, padding=<span class="keyword">True</span>, return_tensors=<span class="string">"pt"</span>)</span><br><span class="line"><span class="meta">>>> </span>print(batch)</span><br><span class="line">{<span class="string">'input_ids'</span>: </span><br><span class="line"> tensor([[ <span class="number">101</span>, <span class="number">2769</span>, <span class="number">3221</span>, <span class="number">671</span>, <span class="number">1368</span>, <span class="number">6413</span>, <span class="number">102</span>, <span class="number">0</span>, <span class="number">0</span>],</span><br><span class="line"> [ <span class="number">101</span>, <span class="number">2769</span>, <span class="number">3221</span>, <span class="number">1369</span>, <span class="number">671</span>, <span class="number">1368</span>, <span class="number">6413</span>, <span class="number">102</span>, <span class="number">0</span>],</span><br><span class="line"> [ <span class="number">101</span>, <span class="number">2769</span>, <span class="number">3221</span>, <span class="number">3297</span>, <span class="number">1400</span>, <span class="number">671</span>, <span class="number">1368</span>, <span class="number">6413</span>, <span class="number">102</span>]]), </span><br><span class="line"> <span class="string">'token_type_ids'</span>: </span><br><span class="line"> tensor([[<span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>],</span><br><span class="line"> [<span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>],</span><br><span class="line"> [<span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>]]), </span><br><span class="line"> <span class="string">'attention_mask'</span>: </span><br><span class="line"> tensor([[<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">0</span>, <span class="number">0</span>],</span><br><span class="line"> [<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">0</span>],</span><br><span class="line"> [<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span 
class="number">1</span>]])}</span><br></pre></td></tr></table></figure><p>对于这种<code>batch_size = 3</code>的场景,不同句子的长度是不同的,<code>padding=True</code>表示短句子的结尾会被填充<code>[PAD]</code>符号,<code>return_tensors="pt"</code>表示返回PyTorch格式的<code>Tensor</code>。<code>attention_mask</code>告诉模型,哪些Token需要被模型关注而加入到模型训练中,哪些Token是被填充进去的无意义的符号,模型无需关注。</p><h4 id="Model"><a href="#Model" class="headerlink" title="Model"></a>Model</h4><p>下面两行代码会创建<code>BertModel</code>,并将所需的模型参数加载进来。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span><span class="keyword">from</span> transformers <span class="keyword">import</span> BertModel</span><br><span class="line"><span class="meta">>>> </span>model = BertModel.from_pretrained(<span class="string">"bert-base-chinese"</span>)</span><br></pre></td></tr></table></figure><p><code>BertModel</code>是一个PyTorch中用来包裹网络结构的<code>torch.nn.Module</code>,<code>BertModel</code>里有<code>forward()</code>方法,<code>forward()</code>方法中实现了将Token转化为词向量,再将词向量进行多层的Transformer Encoder的复杂变换。<code>forward()</code>方法的入参有<code>input_ids</code>、<code>attention_mask</code>、<code>token_type_ids</code>等等,这些参数基本上是刚才Tokenizer部分的输出。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>bert_output = model(input_ids=batch[<span class="string">'input_ids'</span>])</span><br></pre></td></tr></table></figure><p><code>forward()</code>方法返回模型预测的结果,返回结果是一个<code>tuple(torch.FloatTensor)</code>,即多个<code>Tensor</code>组成的<code>tuple</code>。<code>tuple</code>默认返回两个重要的<code>Tensor</code>:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>len(bert_output)</span><br><span class="line"><span class="number">2</span></span><br></pre></td></tr></table></figure><ul><li><strong>last_hidden_state</strong>:输出序列每个位置的语义向量,形状为:(batch_size, sequence_length, hidden_size)。</li><li><strong>pooler_output</strong>:<code>[CLS]</code>符号对应的语义向量,经过了全连接层和tanh激活;该向量可用于下游分类任务。</li></ul><h4 id="下游任务"><a href="#下游任务" class="headerlink" title="下游任务"></a>下游任务</h4><p>BERT可以进行很多下游任务,<code>transformers</code>库中实现了一些下游任务,我们也可以参考<code>transformers</code>中的实现,来做自己想做的任务。比如单文本分类,<code>transformers</code>库提供了<code>BertForSequenceClassification</code>类。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span 
class="title">BertForSequenceClassification</span><span class="params">(BertPreTrainedModel)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, config)</span>:</span></span><br><span class="line"> super().__init__(config)</span><br><span class="line"> self.num_labels = config.num_labels</span><br><span class="line"> self.config = config</span><br><span class="line"></span><br><span class="line"> self.bert = BertModel(config)</span><br><span class="line"> classifier_dropout = ...</span><br><span class="line"> self.dropout = nn.Dropout(classifier_dropout)</span><br><span class="line"> self.classifier = nn.Linear(config.hidden_size, config.num_labels)</span><br><span class="line"></span><br><span class="line"> ...</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(</span></span></span><br><span class="line"><span class="function"><span class="params"> ...</span></span></span><br><span class="line"><span class="function"><span class="params"> )</span>:</span></span><br><span class="line"> ...</span><br><span class="line"></span><br><span class="line"> outputs = self.bert(...)</span><br><span class="line"> pooled_output = outputs[<span class="number">1</span>]</span><br><span class="line"> pooled_output = self.dropout(pooled_output)</span><br><span class="line"> logits = self.classifier(pooled_output)</span><br><span class="line"></span><br><span class="line"> ...</span><br></pre></td></tr></table></figure><p>在这段代码中,<code>BertForSequenceClassification</code>在<code>BertModel</code>基础上,增加了<code>nn.Dropout</code>和<code>nn.Linear</code>层,在预测时,将<code>BertModel</code>的输出放入<code>nn.Linear</code>,完成一个分类任务。除了<code>BertForSequenceClassification</code>,还有<code>BertForQuestionAnswering</code>用于问答,<code>BertForTokenClassification</code>用于序列标注,比如命名实体识别。<code>transformers</code> 中的各个API还有很多其他参数设置,比如得到每一层Transformer Encoder的输出等等,可以访问他们的<a href="https://huggingface.co/docs/transformers/" target="_blank" rel="noopener">文档</a>查看使用方法。</p><p>注:以上内容均转载自 <a href="https://lulaoshi.info/machine-learning/attention/bert" target="_blank" rel="noopener">BERT原理解析及HuggingFace transformers使用入门</a>,侵权删</p>]]></content>
<summary type="html">
<p>最近在做一些图文理解相关的工作,顺带了解了一下BERT,自BERT(Bidirectional Encoder Representations from Transformers)出现后,NLP界开启了一个全新的范式。本文主要介绍BERT的原理,以及如何使用HuggingFace提供的 <code>transformers</code> 库完成基于BERT的微调任务。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="bert" scheme="https://blog.nicehuster.cn/tags/bert/"/>
</entry>
<entry>
<title>文本分词详解</title>
<link href="https://blog.nicehuster.cn/2022/08/02/tokenize/"/>
<id>https://blog.nicehuster.cn/2022/08/02/tokenize/</id>
<published>2022-08-02T11:13:39.000Z</published>
<updated>2023-03-29T02:00:00.007Z</updated>
<content type="html"><![CDATA[<p>在对文本进行处理时,我们需要进行文本预处理 ,而最重要的一步就是分词(Tokenize) 。一个完整的分词流程如下:</p><p><img src="https://img-blog.csdnimg.cn/img_convert/74651d420e612ceddb4bb29dcbac1ba5.png" alt="img"></p><p>其中,执行分词的算法模型称为分词器(Tokenizer) ,划分好的一个个词称为 Token (为啥不直接叫 Word?接着往后看),这个过程称为 Tokenization 。我们将一个个的 token(可以理解为小片段)表示向量,我们分词的目的就是尽可能的让这些向量蕴含更多有用的信息,然后把这些向量输入到算法模型中。由于一篇文本的词往往太多了,为了方便算法模型训练,我们会选取出频率 (也可能是其它的权重)最高的若干个词组成一个词表(Vocabulary) 。</p><a id="more"></a><h4 id="古典分词方法"><a href="#古典分词方法" class="headerlink" title="古典分词方法"></a>古典分词方法</h4><p>分词,顾名思义,就是把一句话分词一个个词,这还不简单?直接把词与词直接加一个空格不就行了?那如果真这么简单我们也不用讨论了,还有什么办法呢,再想一想?或许还能<strong>按标点符号分词</strong> ,或者按<strong>语法规则分词</strong> 。</p><p><img src="https://img-blog.csdnimg.cn/img_convert/bf955428f598919108bcc612e36b8b31.png" alt="img"></p><p>面提到的这些方法,统称为<strong>古典分词方法</strong> ,区别不是很大。一个句子,使用不同的规则,将有许多种不同的分词结果。古典分词方法的缺点非常明显:</p><blockquote><ul><li>对于<strong>未在词表中出现的词(Out Of Vocabulary, OOV</strong> ),模型将无法处理(未知符号标记为 <code>[UNK]</code>)。</li><li>词表中的低频词/稀疏词在模型训无法得到训练(因为词表大小有限,太大的话会影响效率);</li><li>很多语言难以用空格进行分词,例如英语单词的多形态,“look”衍生出的”looks”, “looking”, “looked”,其实都是一个意思,但是在词表中却被当作不同的词处理,模型也无法通过 old, older, oldest 之间的关系学到 smart, smarter, smartest 之间的关系。这一方面增加了训练冗余,另一方面也造成了大词汇量问题。</li></ul></blockquote><h4 id="字符级分词方法"><a href="#字符级分词方法" class="headerlink" title="字符级分词方法"></a>字符级分词方法</h4><p>这种方法称为 Character embedding,是一种更为极端的分词方法,直接把一个词分成一个一个的字母和特殊符号。虽然能解决 OOV 问题,也避免了大词汇量问题,但缺点也太明显了,粒度太细,训练花费的成本太高,但这种思想或许我们后面会用到。</p><p><img src="https://img-blog.csdnimg.cn/img_convert/d4a136846907af983e6632ca17207033.png" alt="img"></p><h4 id="基于子词的分词方法"><a href="#基于子词的分词方法" class="headerlink" title="基于子词的分词方法"></a>基于子词的分词方法</h4><p>基于子词的分词方法(Subword Tokenization),简称为 Subword 算法,意思就是把一个词切成更小的一块一块的子词。如果我们能使用将一个 token 分成多个 subtokens,上面的问题就能很好的解决。这种方法的目的是通过一个有限的词表*来解决所有单词的分词问题,同时尽可能将结果中 token 的数目降到最低。例如,可以用更小的词片段来组成更大的词,例如:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">“unfortunately ” = “un ” + “<span class="keyword">for</span> ” + “tun ” + “ate ” + “ly ”。</span><br></pre></td></tr></table></figure><p>可以看到,有点类似英语中的词根词缀拼词法,其中的这些小片段又可以用来构造其他词。可见这样做,既可以降低词表的大小,同时对相近词也能更好地处理。Subword 与传统分词方法的比较:</p><blockquote><ul><li>传统词表示方法无法很好的处理未知或罕见的词汇(OOV 问题);</li><li>传统词 tokenization 方法不利于模型学习词缀之间的关系,例如模型学到的“old”, “older”, and “oldest”之间的关系无法泛化到“smart”, “smarter”, and “smartest”;</li><li>Character embedding 作为 OOV 的解决方法粒度太细;</li><li>Subword 粒度在词与字符之间,能够较好的平衡 OOV 问题;</li></ul></blockquote><p>目前有三种主流的 Subword 算法,它们分别是:Byte Pair Encoding (BPE)、WordPiece 和 Unigram Language Model。</p><h4 id="BPE"><a href="#BPE" class="headerlink" title="BPE"></a>BPE</h4><p>BPE是一种<strong>数据压缩</strong> 算法,用来在固定大小的词表中实现可变⻓度的子词。该算法简单有效,因而目前它是最流行的方法。BPE 首先将词分成单个字符,然后依次用另一个字符替换频率最高的<strong>一对字符</strong> ,直到循环次数结束。实例如下:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">#Byte Pair Encoding Data Comperssion Example</span></span><br><span class="line">1. aaabdaaabac --> ZabdZabac # replace Z=aa</span><br><span class="line">2. ZabdZabac --> ZYdZYac # replace Y=ab</span><br><span class="line">3. 
ZYdZYac --> XdXac # replace X=ZY</span><br><span class="line"><span class="number">4.</span> XdXac <span class="comment"># Final Compressed string</span></span><br></pre></td></tr></table></figure><h5 id="算法流程"><a href="#算法流程" class="headerlink" title="算法流程"></a>算法流程</h5><p>接下来详细介绍 BPE 在分词中的算法过程:</p><blockquote><ul><li>准备语料库,确定期望的 subword 词表大小等参数;</li><li>通常在每个单词末尾添加后缀 <code></w></code>,统计每个单词出现的频率,例如,<code>low</code> 的频率为 5,那么我们将其改写为 <code>"l o w </ w>”:5</code></li><li>将语料库中所有单词拆分为单个字符,用所有单个字符建立最初的词典,并统计每个字符的频率,本阶段的 subword 的粒度是字符;</li><li><strong>挑出频次最高的符号对</strong> ,比如说 <code>t</code> 和 <code>h</code> 组成的 <code>th</code>,将新字符加入词表,然后将语料中所有该字符对融合(merge),即所有 <code>t</code> 和 <code>h</code> 都变为 <code>th</code>;</li><li>重复遍历 2 和 3 操作,直到<strong>词表中单词数达到设定量</strong> 或<strong>下一个最高频数为 1</strong> ,如果已经打到设定量,其余的词汇直接丢弃;</li></ul></blockquote><p>注:停止符 <code></w></code> 的意义在于标明 subword 是词后缀。举例来说:<code>st</code> 不加 <code></w></code> 可以出现在词首,如 <code>st ar</code>;加了 <code></w></code> 表明该子词位于词尾,如 <code>we st</w></code>,二者意义截然不同;</p><h5 id="代码实现"><a href="#代码实现" class="headerlink" title="代码实现"></a>代码实现</h5><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> re, collections</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">get_vocab</span><span class="params">(filename)</span>:</span></span><br><span class="line"> vocab = collections.defaultdict(int)</span><br><span class="line"> 
<span class="keyword">with</span> open(filename, <span class="string">'r'</span>, encoding=<span class="string">'utf-8'</span>) <span class="keyword">as</span> fhand:</span><br><span class="line"> <span class="keyword">for</span> line <span class="keyword">in</span> fhand:</span><br><span class="line"> words = line.strip().split()</span><br><span class="line"> <span class="keyword">for</span> word <span class="keyword">in</span> words:</span><br><span class="line"> vocab[<span class="string">' '</span>.join(list(word)) + <span class="string">' </w>'</span>] += <span class="number">1</span></span><br><span class="line"> <span class="keyword">return</span> vocab</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">get_stats</span><span class="params">(vocab)</span>:</span></span><br><span class="line"> pairs = collections.defaultdict(int)</span><br><span class="line"> <span class="keyword">for</span> word, freq <span class="keyword">in</span> vocab.items():</span><br><span class="line"> symbols = word.split()</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> range(len(symbols)<span class="number">-1</span>):</span><br><span class="line"> pairs[symbols[i],symbols[i+<span class="number">1</span>]] += freq</span><br><span class="line"> <span class="keyword">return</span> pairs</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">merge_vocab</span><span class="params">(pair, v_in)</span>:</span></span><br><span class="line"> v_out = {}</span><br><span class="line"> bigram = re.escape(<span class="string">' '</span>.join(pair))</span><br><span class="line"> p = re.compile(<span class="string">r'(?<!\S)'</span> + bigram + <span class="string">r'(?!\S)'</span>)</span><br><span class="line"> <span class="keyword">for</span> word <span class="keyword">in</span> v_in:</span><br><span class="line"> w_out = p.sub(<span class="string">''</span>.join(pair), word)</span><br><span class="line"> v_out[w_out] = v_in[word]</span><br><span class="line"> <span class="keyword">return</span> v_out</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">get_tokens</span><span class="params">(vocab)</span>:</span></span><br><span class="line"> tokens = collections.defaultdict(int)</span><br><span class="line"> <span class="keyword">for</span> word, freq <span class="keyword">in</span> vocab.items():</span><br><span class="line"> word_tokens = word.split()</span><br><span class="line"> <span class="keyword">for</span> token <span class="keyword">in</span> word_tokens:</span><br><span class="line"> tokens[token] += freq</span><br><span class="line"> <span class="keyword">return</span> tokens</span><br><span class="line"></span><br><span class="line"><span class="comment"># vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># Get free book from Gutenberg</span></span><br><span class="line"><span class="comment"># wget http://www.gutenberg.org/cache/epub/16457/pg16457.txt</span></span><br><span class="line">vocab = get_vocab(<span class="string">'pg16457.txt'</span>)</span><br><span class="line"></span><br><span class="line">print(<span class="string">'=========='</span>)</span><br><span 
class="line">print(<span class="string">'Tokens Before BPE'</span>)</span><br><span class="line">tokens = get_tokens(vocab)</span><br><span class="line">print(<span class="string">'Tokens: {}'</span>.format(tokens))</span><br><span class="line">print(<span class="string">'Number of tokens: {}'</span>.format(len(tokens)))</span><br><span class="line">print(<span class="string">'=========='</span>)</span><br><span class="line"></span><br><span class="line">num_merges = <span class="number">1000</span></span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> range(num_merges):</span><br><span class="line"> pairs = get_stats(vocab)</span><br><span class="line"> <span class="keyword">if</span> <span class="keyword">not</span> pairs:</span><br><span class="line"> <span class="keyword">break</span></span><br><span class="line"> best = max(pairs, key=pairs.get)</span><br><span class="line"> vocab = merge_vocab(best, vocab)</span><br><span class="line"> print(<span class="string">'Iter: {}'</span>.format(i))</span><br><span class="line"> print(<span class="string">'Best pair: {}'</span>.format(best))</span><br><span class="line"> tokens = get_tokens(vocab)</span><br><span class="line"> print(<span class="string">'Tokens: {}'</span>.format(tokens))</span><br><span class="line"> print(<span class="string">'Number of tokens: {}'</span>.format(len(tokens)))</span><br><span class="line"> print(<span class="string">'=========='</span>)</span><br></pre></td></tr></table></figure><p>输出如下:</p><figure class="highlight groovy"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line">==========</span><br><span class="line">Tokens Before BPE</span><br><span class="line"><span class="string">Tokens:</span> defaultdict(<<span class="class"><span class="keyword">class</span> '<span class="title">int</span>'>, {</span><span class="string">'\ufeff'</span>: <span class="number">1</span>, <span class="string">'T'</span>: <span class="number">1610</span>, <span class="string">'h'</span>: <span class="number">26094</span>, <span class="string">'e'</span>: <span class="number">59152</span>, <span class="string">'</w>'</span>: <span class="number">101830</span>, <span class="string">'P'</span>: <span class="number">780</span>, <span class="string">'r'</span>: <span class="number">29540</span>, <span class="string">'o'</span>: <span class="number">34983</span>, <span class="string">'j'</span>: <span class="number">857</span>, <span class="string">'c'</span>: <span class="number">13891</span>, <span class="string">'t'</span>: <span class="number">44258</span>, <span class="string">'G'</span>: <span class="number">300</span>, <span class="string">'u'</span>: <span class="number">13731</span>, <span class="string">'n'</span>: <span class="number">32499</span>, <span class="string">'b'</span>: <span class="number">7428</span>, <span 
class="string">'g'</span>: <span class="number">8744</span>, <span class="string">'E'</span>: <span class="number">901</span>, <span class="string">'B'</span>: <span class="number">1163</span>, <span class="string">'k'</span>: <span class="number">2726</span>, <span class="string">'f'</span>: <span class="number">10469</span>, <span class="string">'A'</span>: <span class="number">1381</span>, <span class="string">'l'</span>: <span class="number">20632</span>, <span class="string">'d'</span>: <span class="number">17576</span>, <span class="string">'M'</span>: <span class="number">1206</span>, <span class="string">','</span>: <span class="number">8068</span>, <span class="string">'y'</span>: <span class="number">8812</span>, <span class="string">'J'</span>: <span class="number">80</span>, <span class="string">'s'</span>: <span class="number">28320</span>, <span class="string">'V'</span>: <span class="number">104</span>, <span class="string">'i'</span>: <span class="number">31435</span>, <span class="string">'a'</span>: <span class="number">36692</span>, <span class="string">'w'</span>: <span class="number">8133</span>, <span class="string">'m'</span>: <span class="number">9812</span>, <span class="string">'v'</span>: <span class="number">4880</span>, <span class="string">'.'</span>: <span class="number">4055</span>, <span class="string">'Y'</span>: <span class="number">250</span>, <span class="string">'p'</span>: <span class="number">8040</span>, <span class="string">'-'</span>: <span class="number">1128</span>, <span class="string">'L'</span>: <span class="number">429</span>, <span class="string">':'</span>: <span class="number">209</span>, <span class="string">'R'</span>: <span class="number">369</span>, <span class="string">'D'</span>: <span class="number">327</span>, <span class="string">'6'</span>: <span class="number">77</span>, <span class="string">'2'</span>: <span class="number">158</span>, <span class="string">'0'</span>: <span class="number">401</span>, <span class="string">'5'</span>: <span class="number">131</span>, <span class="string">'['</span>: <span class="number">32</span>, <span class="string">'#'</span>: <span class="number">1</span>, <span class="string">'1'</span>: <span class="number">295</span>, <span class="string">'4'</span>: <span class="number">104</span>, <span class="string">'7'</span>: <span class="number">65</span>, <span class="string">']'</span>: <span class="number">32</span>, <span class="string">'*'</span>: <span class="number">44</span>, <span class="string">'S'</span>: <span class="number">860</span>, <span class="string">'O'</span>: <span class="number">510</span>, <span class="string">'F'</span>: <span class="number">422</span>, <span class="string">'H'</span>: <span class="number">689</span>, <span class="string">'I'</span>: <span class="number">1432</span>, <span class="string">'C'</span>: <span class="number">863</span>, <span class="string">'U'</span>: <span class="number">170</span>, <span class="string">'N'</span>: <span class="number">796</span>, <span class="string">'K'</span>: <span class="number">42</span>, <span class="string">'/'</span>: <span class="number">52</span>, <span class="string">'"'</span>: <span class="number">4086</span>, <span class="string">'!'</span>: <span class="number">1214</span>, <span class="string">'W'</span>: <span class="number">579</span>, <span class="string">'3'</span>: <span class="number">105</span>, <span class="string">"'"</span>: <span class="number">1243</span>, <span class="string">'Q'</span>: <span 
class="number">33</span>, <span class="string">'X'</span>: <span class="number">49</span>, <span class="string">'Z'</span>: <span class="number">10</span>, <span class="string">'?'</span>: <span class="number">651</span>, <span class="string">'8'</span>: <span class="number">75</span>, <span class="string">'9'</span>: <span class="number">38</span>, <span class="string">'_'</span>: <span class="number">1426</span>, <span class="string">'à'</span>: <span class="number">3</span>, <span class="string">'x'</span>: <span class="number">937</span>, <span class="string">'z'</span>: <span class="number">365</span>, <span class="string">'°'</span>: <span class="number">41</span>, <span class="string">'q'</span>: <span class="number">575</span>, <span class="string">';'</span>: <span class="number">561</span>, <span class="string">'('</span>: <span class="number">56</span>, <span class="string">')'</span>: <span class="number">56</span>, <span class="string">'{'</span>: <span class="number">23</span>, <span class="string">'}'</span>: <span class="number">16</span>, <span class="string">'è'</span>: <span class="number">2</span>, <span class="string">'é'</span>: <span class="number">14</span>, <span class="string">'+'</span>: <span class="number">2</span>, <span class="string">'='</span>: <span class="number">3</span>, <span class="string">'ö'</span>: <span class="number">2</span>, <span class="string">'ê'</span>: <span class="number">5</span>, <span class="string">'â'</span>: <span class="number">1</span>, <span class="string">'ô'</span>: <span class="number">1</span>, <span class="string">'Æ'</span>: <span class="number">3</span>, <span class="string">'æ'</span>: <span class="number">2</span>, <span class="string">'%'</span>: <span class="number">1</span>, <span class="string">'@'</span>: <span class="number">2</span>, <span class="string">'$'</span>: <span class="number">2</span>})</span><br><span class="line">Number of <span class="string">tokens:</span> <span class="number">98</span></span><br><span class="line">==========</span><br><span class="line"><span class="string">Iter:</span> <span class="number">0</span></span><br><span class="line">Best <span class="string">pair:</span> (<span class="string">'e'</span>, <span class="string">'</w>'</span>)</span><br><span class="line"><span class="string">Tokens:</span> defaultdict(<<span class="class"><span class="keyword">class</span> '<span class="title">int</span>'>, {</span><span class="string">'\ufeff'</span>: <span class="number">1</span>, <span class="string">'T'</span>: <span class="number">1610</span>, <span class="string">'h'</span>: <span class="number">26094</span>, <span class="string">'e</w>'</span>: <span class="number">17749</span>, <span class="string">'P'</span>: <span class="number">780</span>, <span class="string">'r'</span>: <span class="number">29540</span>, <span class="string">'o'</span>: <span class="number">34983</span>, <span class="string">'j'</span>: <span class="number">857</span>, <span class="string">'e'</span>: <span class="number">41403</span>, <span class="string">'c'</span>: <span class="number">13891</span>, <span class="string">'t'</span>: <span class="number">44258</span>, <span class="string">'</w>'</span>: <span class="number">84081</span>, <span class="string">'G'</span>: <span class="number">300</span>, <span class="string">'u'</span>: <span class="number">13731</span>, <span class="string">'n'</span>: <span class="number">32499</span>, <span class="string">'b'</span>: <span class="number">7428</span>, <span 
class="string">'g'</span>: <span class="number">8744</span>, <span class="string">'E'</span>: <span class="number">901</span>, <span class="string">'B'</span>: <span class="number">1163</span>, <span class="string">'k'</span>: <span class="number">2726</span>, <span class="string">'f'</span>: <span class="number">10469</span>, <span class="string">'A'</span>: <span class="number">1381</span>, <span class="string">'l'</span>: <span class="number">20632</span>, <span class="string">'d'</span>: <span class="number">17576</span>, <span class="string">'M'</span>: <span class="number">1206</span>, <span class="string">','</span>: <span class="number">8068</span>, <span class="string">'y'</span>: <span class="number">8812</span>, <span class="string">'J'</span>: <span class="number">80</span>, <span class="string">'s'</span>: <span class="number">28320</span>, <span class="string">'V'</span>: <span class="number">104</span>, <span class="string">'i'</span>: <span class="number">31435</span>, <span class="string">'a'</span>: <span class="number">36692</span>, <span class="string">'w'</span>: <span class="number">8133</span>, <span class="string">'m'</span>: <span class="number">9812</span>, <span class="string">'v'</span>: <span class="number">4880</span>, <span class="string">'.'</span>: <span class="number">4055</span>, <span class="string">'Y'</span>: <span class="number">250</span>, <span class="string">'p'</span>: <span class="number">8040</span>, <span class="string">'-'</span>: <span class="number">1128</span>, <span class="string">'L'</span>: <span class="number">429</span>, <span class="string">':'</span>: <span class="number">209</span>, <span class="string">'R'</span>: <span class="number">369</span>, <span class="string">'D'</span>: <span class="number">327</span>, <span class="string">'6'</span>: <span class="number">77</span>, <span class="string">'2'</span>: <span class="number">158</span>, <span class="string">'0'</span>: <span class="number">401</span>, <span class="string">'5'</span>: <span class="number">131</span>, <span class="string">'['</span>: <span class="number">32</span>, <span class="string">'#'</span>: <span class="number">1</span>, <span class="string">'1'</span>: <span class="number">295</span>, <span class="string">'4'</span>: <span class="number">104</span>, <span class="string">'7'</span>: <span class="number">65</span>, <span class="string">']'</span>: <span class="number">32</span>, <span class="string">'*'</span>: <span class="number">44</span>, <span class="string">'S'</span>: <span class="number">860</span>, <span class="string">'O'</span>: <span class="number">510</span>, <span class="string">'F'</span>: <span class="number">422</span>, <span class="string">'H'</span>: <span class="number">689</span>, <span class="string">'I'</span>: <span class="number">1432</span>, <span class="string">'C'</span>: <span class="number">863</span>, <span class="string">'U'</span>: <span class="number">170</span>, <span class="string">'N'</span>: <span class="number">796</span>, <span class="string">'K'</span>: <span class="number">42</span>, <span class="string">'/'</span>: <span class="number">52</span>, <span class="string">'"'</span>: <span class="number">4086</span>, <span class="string">'!'</span>: <span class="number">1214</span>, <span class="string">'W'</span>: <span class="number">579</span>, <span class="string">'3'</span>: <span class="number">105</span>, <span class="string">"'"</span>: <span class="number">1243</span>, <span class="string">'Q'</span>: <span 
class="number">33</span>, <span class="string">'X'</span>: <span class="number">49</span>, <span class="string">'Z'</span>: <span class="number">10</span>, <span class="string">'?'</span>: <span class="number">651</span>, <span class="string">'8'</span>: <span class="number">75</span>, <span class="string">'9'</span>: <span class="number">38</span>, <span class="string">'_'</span>: <span class="number">1426</span>, <span class="string">'à'</span>: <span class="number">3</span>, <span class="string">'x'</span>: <span class="number">937</span>, <span class="string">'z'</span>: <span class="number">365</span>, <span class="string">'°'</span>: <span class="number">41</span>, <span class="string">'q'</span>: <span class="number">575</span>, <span class="string">';'</span>: <span class="number">561</span>, <span class="string">'('</span>: <span class="number">56</span>, <span class="string">')'</span>: <span class="number">56</span>, <span class="string">'{'</span>: <span class="number">23</span>, <span class="string">'}'</span>: <span class="number">16</span>, <span class="string">'è'</span>: <span class="number">2</span>, <span class="string">'é'</span>: <span class="number">14</span>, <span class="string">'+'</span>: <span class="number">2</span>, <span class="string">'='</span>: <span class="number">3</span>, <span class="string">'ö'</span>: <span class="number">2</span>, <span class="string">'ê'</span>: <span class="number">5</span>, <span class="string">'â'</span>: <span class="number">1</span>, <span class="string">'ô'</span>: <span class="number">1</span>, <span class="string">'Æ'</span>: <span class="number">3</span>, <span class="string">'æ'</span>: <span class="number">2</span>, <span class="string">'%'</span>: <span class="number">1</span>, <span class="string">'@'</span>: <span class="number">2</span>, <span class="string">'$'</span>: <span class="number">2</span>})</span><br><span class="line">Number of <span class="string">tokens:</span> <span class="number">99</span></span><br><span class="line">==========</span><br><span class="line"><span class="string">Iter:</span> <span class="number">1</span></span><br><span class="line">Best <span class="string">pair:</span> (<span class="string">'t'</span>, <span class="string">'h'</span>)</span><br><span class="line"><span class="string">Tokens:</span> defaultdict(<<span class="class"><span class="keyword">class</span> '<span class="title">int</span>'>, {</span><span class="string">'\ufeff'</span>: <span class="number">1</span>, <span class="string">'T'</span>: <span class="number">1610</span>, <span class="string">'h'</span>: <span class="number">12065</span>, <span class="string">'e</w>'</span>: <span class="number">17749</span>, <span class="string">'P'</span>: <span class="number">780</span>, <span class="string">'r'</span>: <span class="number">29540</span>, <span class="string">'o'</span>: <span class="number">34983</span>, <span class="string">'j'</span>: <span class="number">857</span>, <span class="string">'e'</span>: <span class="number">41403</span>, <span class="string">'c'</span>: <span class="number">13891</span>, <span class="string">'t'</span>: <span class="number">30229</span>, <span class="string">'</w>'</span>: <span class="number">84081</span>, <span class="string">'G'</span>: <span class="number">300</span>, <span class="string">'u'</span>: <span class="number">13731</span>, <span class="string">'n'</span>: <span class="number">32499</span>, <span class="string">'b'</span>: <span class="number">7428</span>, <span 
class="string">'g'</span>: <span class="number">8744</span>, <span class="string">'E'</span>: <span class="number">901</span>, <span class="string">'B'</span>: <span class="number">1163</span>, <span class="string">'k'</span>: <span class="number">2726</span>, <span class="string">'f'</span>: <span class="number">10469</span>, <span class="string">'A'</span>: <span class="number">1381</span>, <span class="string">'l'</span>: <span class="number">20632</span>, <span class="string">'d'</span>: <span class="number">17576</span>, <span class="string">'th'</span>: <span class="number">14029</span>, <span class="string">'M'</span>: <span class="number">1206</span>, <span class="string">','</span>: <span class="number">8068</span>, <span class="string">'y'</span>: <span class="number">8812</span>, <span class="string">'J'</span>: <span class="number">80</span>, <span class="string">'s'</span>: <span class="number">28320</span>, <span class="string">'V'</span>: <span class="number">104</span>, <span class="string">'i'</span>: <span class="number">31435</span>, <span class="string">'a'</span>: <span class="number">36692</span>, <span class="string">'w'</span>: <span class="number">8133</span>, <span class="string">'m'</span>: <span class="number">9812</span>, <span class="string">'v'</span>: <span class="number">4880</span>, <span class="string">'.'</span>: <span class="number">4055</span>, <span class="string">'Y'</span>: <span class="number">250</span>, <span class="string">'p'</span>: <span class="number">8040</span>, <span class="string">'-'</span>: <span class="number">1128</span>, <span class="string">'L'</span>: <span class="number">429</span>, <span class="string">':'</span>: <span class="number">209</span>, <span class="string">'R'</span>: <span class="number">369</span>, <span class="string">'D'</span>: <span class="number">327</span>, <span class="string">'6'</span>: <span class="number">77</span>, <span class="string">'2'</span>: <span class="number">158</span>, <span class="string">'0'</span>: <span class="number">401</span>, <span class="string">'5'</span>: <span class="number">131</span>, <span class="string">'['</span>: <span class="number">32</span>, <span class="string">'#'</span>: <span class="number">1</span>, <span class="string">'1'</span>: <span class="number">295</span>, <span class="string">'4'</span>: <span class="number">104</span>, <span class="string">'7'</span>: <span class="number">65</span>, <span class="string">']'</span>: <span class="number">32</span>, <span class="string">'*'</span>: <span class="number">44</span>, <span class="string">'S'</span>: <span class="number">860</span>, <span class="string">'O'</span>: <span class="number">510</span>, <span class="string">'F'</span>: <span class="number">422</span>, <span class="string">'H'</span>: <span class="number">689</span>, <span class="string">'I'</span>: <span class="number">1432</span>, <span class="string">'C'</span>: <span class="number">863</span>, <span class="string">'U'</span>: <span class="number">170</span>, <span class="string">'N'</span>: <span class="number">796</span>, <span class="string">'K'</span>: <span class="number">42</span>, <span class="string">'/'</span>: <span class="number">52</span>, <span class="string">'"'</span>: <span class="number">4086</span>, <span class="string">'!'</span>: <span class="number">1214</span>, <span class="string">'W'</span>: <span class="number">579</span>, <span class="string">'3'</span>: <span class="number">105</span>, <span class="string">"'"</span>: <span 
class="number">1243</span>, <span class="string">'Q'</span>: <span class="number">33</span>, <span class="string">'X'</span>: <span class="number">49</span>, <span class="string">'Z'</span>: <span class="number">10</span>, <span class="string">'?'</span>: <span class="number">651</span>, <span class="string">'8'</span>: <span class="number">75</span>, <span class="string">'9'</span>: <span class="number">38</span>, <span class="string">'_'</span>: <span class="number">1426</span>, <span class="string">'à'</span>: <span class="number">3</span>, <span class="string">'x'</span>: <span class="number">937</span>, <span class="string">'z'</span>: <span class="number">365</span>, <span class="string">'°'</span>: <span class="number">41</span>, <span class="string">'q'</span>: <span class="number">575</span>, <span class="string">';'</span>: <span class="number">561</span>, <span class="string">'('</span>: <span class="number">56</span>, <span class="string">')'</span>: <span class="number">56</span>, <span class="string">'{'</span>: <span class="number">23</span>, <span class="string">'}'</span>: <span class="number">16</span>, <span class="string">'è'</span>: <span class="number">2</span>, <span class="string">'é'</span>: <span class="number">14</span>, <span class="string">'+'</span>: <span class="number">2</span>, <span class="string">'='</span>: <span class="number">3</span>, <span class="string">'ö'</span>: <span class="number">2</span>, <span class="string">'ê'</span>: <span class="number">5</span>, <span class="string">'â'</span>: <span class="number">1</span>, <span class="string">'ô'</span>: <span class="number">1</span>, <span class="string">'Æ'</span>: <span class="number">3</span>, <span class="string">'æ'</span>: <span class="number">2</span>, <span class="string">'%'</span>: <span class="number">1</span>, <span class="string">'@'</span>: <span class="number">2</span>, <span class="string">'$'</span>: <span class="number">2</span>})</span><br><span class="line">Number of <span class="string">tokens:</span> <span class="number">100</span></span><br><span class="line">==========</span><br><span class="line"><span class="string">Iter:</span> <span class="number">2</span></span><br><span class="line">Best <span class="string">pair:</span> (<span class="string">'t'</span>, <span class="string">'</w>'</span>)</span><br><span class="line"><span class="string">Tokens:</span> defaultdict(<<span class="class"><span class="keyword">class</span> '<span class="title">int</span>'>, {</span><span class="string">'\ufeff'</span>: <span class="number">1</span>, <span class="string">'T'</span>: <span class="number">1610</span>, <span class="string">'h'</span>: <span class="number">12065</span>, <span class="string">'e</w>'</span>: <span class="number">17749</span>, <span class="string">'P'</span>: <span class="number">780</span>, <span class="string">'r'</span>: <span class="number">29540</span>, <span class="string">'o'</span>: <span class="number">34983</span>, <span class="string">'j'</span>: <span class="number">857</span>, <span class="string">'e'</span>: <span class="number">41403</span>, <span class="string">'c'</span>: <span class="number">13891</span>, <span class="string">'t</w>'</span>: <span class="number">9271</span>, <span class="string">'G'</span>: <span class="number">300</span>, <span class="string">'u'</span>: <span class="number">13731</span>, <span class="string">'t'</span>: <span class="number">20958</span>, <span class="string">'n'</span>: <span class="number">32499</span>, <span 
class="string">'b'</span>: <span class="number">7428</span>, <span class="string">'g'</span>: <span class="number">8744</span>, <span class="string">'</w>'</span>: <span class="number">74810</span>, <span class="string">'E'</span>: <span class="number">901</span>, <span class="string">'B'</span>: <span class="number">1163</span>, <span class="string">'k'</span>: <span class="number">2726</span>, <span class="string">'f'</span>: <span class="number">10469</span>, <span class="string">'A'</span>: <span class="number">1381</span>, <span class="string">'l'</span>: <span class="number">20632</span>, <span class="string">'d'</span>: <span class="number">17576</span>, <span class="string">'th'</span>: <span class="number">14029</span>, <span class="string">'M'</span>: <span class="number">1206</span>, <span class="string">','</span>: <span class="number">8068</span>, <span class="string">'y'</span>: <span class="number">8812</span>, <span class="string">'J'</span>: <span class="number">80</span>, <span class="string">'s'</span>: <span class="number">28320</span>, <span class="string">'V'</span>: <span class="number">104</span>, <span class="string">'i'</span>: <span class="number">31435</span>, <span class="string">'a'</span>: <span class="number">36692</span>, <span class="string">'w'</span>: <span class="number">8133</span>, <span class="string">'m'</span>: <span class="number">9812</span>, <span class="string">'v'</span>: <span class="number">4880</span>, <span class="string">'.'</span>: <span class="number">4055</span>, <span class="string">'Y'</span>: <span class="number">250</span>, <span class="string">'p'</span>: <span class="number">8040</span>, <span class="string">'-'</span>: <span class="number">1128</span>, <span class="string">'L'</span>: <span class="number">429</span>, <span class="string">':'</span>: <span class="number">209</span>, <span class="string">'R'</span>: <span class="number">369</span>, <span class="string">'D'</span>: <span class="number">327</span>, <span class="string">'6'</span>: <span class="number">77</span>, <span class="string">'2'</span>: <span class="number">158</span>, <span class="string">'0'</span>: <span class="number">401</span>, <span class="string">'5'</span>: <span class="number">131</span>, <span class="string">'['</span>: <span class="number">32</span>, <span class="string">'#'</span>: <span class="number">1</span>, <span class="string">'1'</span>: <span class="number">295</span>, <span class="string">'4'</span>: <span class="number">104</span>, <span class="string">'7'</span>: <span class="number">65</span>, <span class="string">']'</span>: <span class="number">32</span>, <span class="string">'*'</span>: <span class="number">44</span>, <span class="string">'S'</span>: <span class="number">860</span>, <span class="string">'O'</span>: <span class="number">510</span>, <span class="string">'F'</span>: <span class="number">422</span>, <span class="string">'H'</span>: <span class="number">689</span>, <span class="string">'I'</span>: <span class="number">1432</span>, <span class="string">'C'</span>: <span class="number">863</span>, <span class="string">'U'</span>: <span class="number">170</span>, <span class="string">'N'</span>: <span class="number">796</span>, <span class="string">'K'</span>: <span class="number">42</span>, <span class="string">'/'</span>: <span class="number">52</span>, <span class="string">'"'</span>: <span class="number">4086</span>, <span class="string">'!'</span>: <span class="number">1214</span>, <span class="string">'W'</span>: 
<span class="number">579</span>, <span class="string">'3'</span>: <span class="number">105</span>, <span class="string">"'"</span>: <span class="number">1243</span>, <span class="string">'Q'</span>: <span class="number">33</span>, <span class="string">'X'</span>: <span class="number">49</span>, <span class="string">'Z'</span>: <span class="number">10</span>, <span class="string">'?'</span>: <span class="number">651</span>, <span class="string">'8'</span>: <span class="number">75</span>, <span class="string">'9'</span>: <span class="number">38</span>, <span class="string">'_'</span>: <span class="number">1426</span>, <span class="string">'à'</span>: <span class="number">3</span>, <span class="string">'x'</span>: <span class="number">937</span>, <span class="string">'z'</span>: <span class="number">365</span>, <span class="string">'°'</span>: <span class="number">41</span>, <span class="string">'q'</span>: <span class="number">575</span>, <span class="string">';'</span>: <span class="number">561</span>, <span class="string">'('</span>: <span class="number">56</span>, <span class="string">')'</span>: <span class="number">56</span>, <span class="string">'{'</span>: <span class="number">23</span>, <span class="string">'}'</span>: <span class="number">16</span>, <span class="string">'è'</span>: <span class="number">2</span>, <span class="string">'é'</span>: <span class="number">14</span>, <span class="string">'+'</span>: <span class="number">2</span>, <span class="string">'='</span>: <span class="number">3</span>, <span class="string">'ö'</span>: <span class="number">2</span>, <span class="string">'ê'</span>: <span class="number">5</span>, <span class="string">'â'</span>: <span class="number">1</span>, <span class="string">'ô'</span>: <span class="number">1</span>, <span class="string">'Æ'</span>: <span class="number">3</span>, <span class="string">'æ'</span>: <span class="number">2</span>, <span class="string">'%'</span>: <span class="number">1</span>, <span class="string">'@'</span>: <span class="number">2</span>, <span class="string">'$'</span>: <span class="number">2</span>})</span><br><span class="line">Number of <span class="string">tokens:</span> <span class="number">101</span></span><br><span class="line">==========</span><br></pre></td></tr></table></figure><h5 id="优点"><a href="#优点" class="headerlink" title="优点"></a>优点</h5><p>可以有效地平衡词典大小和编码后的token数量;随着合并的次数增加,词表大小通常先增加后减小。迭代次数太小,大部分还是字母,没什么意义;迭代次数多,又重新变回了原来那几个词。所以词表大小要取一个中间值。</p><p><img src="https://img-blog.csdnimg.cn/img_convert/7f31f525c34590fad917ad0ee89a320d.png" alt="img"></p><h5 id="适用范围"><a href="#适用范围" class="headerlink" title="适用范围"></a>适用范围</h5><p>BPE 一般适用在欧美语言拉丁语系中,因为欧美语言大多是字符形式,涉及前缀、后缀的单词比较多。而中文的汉字一般不用 BPE 进行编码,因为中文是字无法进行拆分。对中文的处理通常只有分词和分字两种。理论上分词效果更好,更好的区别语义。分字效率高、简洁,因为常用的字不过 3000 字,词表更加简短。</p><h4 id="WordPiece"><a href="#WordPiece" class="headerlink" title="WordPiece"></a>WordPiece</h4><p>WordPiece与BPE非常相似,也是每次从词表中选出两个子词合并成新的子词,区别在于,BPE选择频数最高的相邻子词合并,而WordPiece选择能够提升<strong>语言模型概率最大</strong>的相邻子词加入词表。假设句子$S=(t_1,t_2,…t_n)$ 是由n个字词组成,$t_i$表示字词,且假设各个字词之间是独立存在的,则句子的语言模型似然值等价于所有字词概率的乘积:</p><script type="math/tex; mode=display">logP(S)=\sum_{i=1}^{n}logP(t_i)</script><p>设把相邻位置的x和y两个子词进行合并,合并后产生的子词为z,此时句子<em>S</em>似然值的变化可表示为:</p><script type="math/tex; mode=display">\log P\left(t_z\right)-\left(\log P\left(t_x\right)+\log P\left(t_y\right)\right)=\log \left(\frac{P\left(t_z\right)}{P\left(t_x\right) 
P\left(t_y\right)}\right)</script><p>可以看见似然值的变化就是两个子词之间的互信息。简而言之,WordPiece每次选择合并的两个子词,他们具有最大的互信息,也就是两个子词在语言模型上具有较强的关联性,它们经常在语料中以相邻的方式同时出现。</p><h4 id="Unigram-Language-Model"><a href="#Unigram-Language-Model" class="headerlink" title="Unigram Language Model"></a>Unigram Language Model</h4><p>Unigram与BPE和WordPiece的区别在于,BPE和Worpiece算法的词表都是一点一点增加,由小到大的。而Unigram则是先初始化一个非常巨大的词表,然后根据标准不断的丢弃,知道词表大小满足限定条件。Unigram算法考虑了句子的不同分词可能,因而能够出输出带概率的子词分段。详细算法可以看原文:<a href="https://arxiv.org/pdf/1804.10959.pdf" target="_blank" rel="noopener">https://arxiv.org/pdf/1804.10959.pdf</a></p><h4 id="SentencePiece"><a href="#SentencePiece" class="headerlink" title="SentencePiece"></a>SentencePiece</h4><p>上述的所有算法都有一个前提:输入以空格来进行区分。然而并不是所有语言的词语都是使用空格来进行分割(比如中文、日文),一种比较常见的做法是使用预分词。为了更加一般化的解决这个问题,谷歌推出了开源工具包<a href="https://arxiv.org/pdf/1808.06226.pdf" target="_blank" rel="noopener">SentencePiece</a> 。SentencePiece是把一个句子看做一个整体,再拆成片段,而没有保留天然的词语的概念。一般地,它把space也当做一种特殊的字符来处理,再用BPE或者Unigram算法来构造词汇表。比如,XLNetTokenizer就采用了_来代替空格,解码的时候会再用空格替换回来。目前,Tokenizers库中,所有使用了SentencePiece的都是与Unigram算法联合使用的,比如ALBERT、XLNet、Marian和T5.</p>]]></content>
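<p>补充:上面的 BPE 代码只演示了如何从语料中学出一系列合并规则(每轮的 best pair),并没有展示对新词分词时如何使用这些规则。推理阶段的做法是把新词拆成字符(末尾加 &lt;/w&gt;),再按训练时的顺序重放这些合并。下面是一个示意性实现,函数名 bpe_tokenize 和 merges 列表都是虚构的例子,仅用于说明流程;实际实现(如 subword-nmt、HuggingFace tokenizers)还会处理未登录字符、合并规则的优先级查找等细节。</p>

```python
def bpe_tokenize(word, merges):
    # merges: 训练阶段按合并先后顺序记录下来的符号对列表,例如 [('e', 's'), ('es', 't'), ...]
    # 示意实现:把单词拆成字符并加上 </w>,再按训练时的顺序依次应用合并规则
    symbols = list(word) + ['</w>']
    for a, b in merges:
        merged, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and symbols[i] == a and symbols[i + 1] == b:
                merged.append(a + b)   # 命中规则,合并相邻两个符号
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

# 假设训练时依次学到的合并规则如下(仅为虚构示例)
merges = [('e', 's'), ('es', 't'), ('est', '</w>'), ('l', 'o'), ('lo', 'w')]
print(bpe_tokenize('lowest', merges))   # ['low', 'est</w>']
```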
<summary type="html">
<p>在对文本进行处理时,我们需要进行文本预处理 ,而最重要的一步就是分词(Tokenize) 。一个完整的分词流程如下:</p>
<p><img src="https://img-blog.csdnimg.cn/img_convert/74651d420e612ceddb4bb29dcbac1ba5.png" alt="img"></p>
<p>其中,执行分词的算法模型称为分词器(Tokenizer) ,划分好的一个个词称为 Token (为啥不直接叫 Word?接着往后看),这个过程称为 Tokenization 。我们会把一个个 token(可以理解为小片段)表示成向量,分词的目的就是尽可能地让这些向量蕴含更多有用的信息,然后把这些向量输入到算法模型中。由于一篇文本的词往往太多了,为了方便算法模型训练,我们会选取出频率 (也可能是其它的权重)最高的若干个词组成一个词表(Vocabulary) 。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="分词" scheme="https://blog.nicehuster.cn/tags/%E5%88%86%E8%AF%8D/"/>
</entry>
<entry>
<title>Deformable-DETR详解与代码解读</title>
<link href="https://blog.nicehuster.cn/2022/07/26/Deformable%20DETR/"/>
<id>https://blog.nicehuster.cn/2022/07/26/Deformable DETR/</id>
<published>2022-07-26T11:13:39.000Z</published>
<updated>2022-08-26T02:16:35.536Z</updated>
<content type="html"><![CDATA[<p>DETR是第一个end2end的目标检测器,不需要众多手工设计组件(anchor,iou匹配,nms后处理等),但也存在收敛慢,能处理的特征分辨率有限等缺陷。原因大概存在如下:</p><blockquote><ul><li>transformer在初始化时,分配给所有特征像素的注意力权重几乎均等;这就造成了模型需要长时间去学习关注真正有意义的位置,这些位置应该是稀疏的;</li><li>transformer在计算注意力权重时,伴随着高计算量与空间复杂度。特别是在编码器部分,与特征像素点的数量成平方级关系,因此难以处理高分辨率的特征;</li></ul></blockquote><a id="more"></a><p>Deformable DETR的工作就在于解决DETR收敛慢以及高计算复杂度问题。具体做法有:</p><h4 id="多尺度特征-amp-多尺度Embedding"><a href="#多尺度特征-amp-多尺度Embedding" class="headerlink" title="多尺度特征 & 多尺度Embedding"></a>多尺度特征 & 多尺度Embedding</h4><p>在DETR中,由于计算复杂度的问题,仅仅只使用了<strong>单尺度特征</strong>,对于特征点位置信息编码使用三角函数,不同位置对应不同编码值。而在多尺度特征中,位于不同特征层的特征点可能拥有相同的(h,w)坐标,使用一套位置编码是无法区分。因此作者使用多尺度特征时,增加了scale-level embedding,用于区分不同特征层。不同于三角函数使用固定公式计算编码,scale-level embedding是可学习的。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">DeformableTransformer</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, d_model=<span class="number">256</span>, nhead=<span class="number">8</span>,</span></span></span><br><span class="line"><span class="function"><span class="params"> num_encoder_layers=<span class="number">6</span>, num_decoder_layers=<span class="number">6</span>, dim_feedforward=<span class="number">1024</span>, dropout=<span class="number">0.1</span>,</span></span></span><br><span class="line"><span class="function"><span class="params"> activation=<span class="string">"relu"</span>, return_intermediate_dec=False,</span></span></span><br><span class="line"><span class="function"><span class="params"> num_feature_levels=<span class="number">4</span>, dec_n_points=<span class="number">4</span>, enc_n_points=<span class="number">4</span>,</span></span></span><br><span class="line"><span class="function"><span class="params"> two_stage=False, two_stage_num_proposals=<span class="number">300</span>)</span>:</span></span><br><span class="line"> super().__init__()</span><br><span class="line"> ...</span><br><span class="line"> <span class="comment"># scale level embedding,对4个特征层分别附加d_model维度的embedding,用于区分query对应的具体特征层</span></span><br><span class="line"> self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))</span><br><span class="line"> ...</span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, srcs, masks, pos_embeds, query_embed=None)</span>:</span></span><br><span class="line"> ...</span><br><span class="line"> <span class="keyword">for</span> lvl, (src, mask, pos_embed) <span class="keyword">in</span> enumerate(zip(srcs, masks, pos_embeds)):</span><br><span class="line"> ...</span><br><span 
class="line"> <span class="comment"># (bs,c,h,w) --> (bs,h*w,c)</span></span><br><span class="line"> pos_embed = pos_embed.flatten(<span class="number">2</span>).transpose(<span class="number">1</span>, <span class="number">2</span>)</span><br><span class="line"> <span class="comment">#与position embedding相加,且同一特征层,所有query的scale-level embedding是相同的</span></span><br><span class="line"> lvl_pos_embed = pos_embed + self.level_embed[lvl].view(<span class="number">1</span>, <span class="number">1</span>, <span class="number">-1</span>)</span><br><span class="line"> ...</span><br></pre></td></tr></table></figure><h4 id="Deformable-Attention"><a href="#Deformable-Attention" class="headerlink" title="Deformable Attention"></a>Deformable Attention</h4><p>通俗地来讲,可变性注意力即,query不是和全局每个位置的key都计算注意力权重,而仅在全局位置中采样部分位置的key,并且value也是基于这些位置进行采样插值得到,最后将局部&稀疏的注意力权重施加在对应的value上。</p><p><img src="/img/deformAttn.png" alt></p><p>如上图所示,每个query在每个head上采样K个位置,只需和这些位置的特征进行交互,不同于detr那样,每个query需要与全局位置进行交互。需要注意的是,位置偏移量$\Delta p_{mqx}$ 是由query经过全连接得到的,注意力权重也是由query经全连接层得到的,同时在K个采样点之间进行权重归一化。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">MSDeformAttn</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, d_model=<span class="number">256</span>, n_levels=<span class="number">4</span>, n_heads=<span class="number">8</span>, n_points=<span class="number">4</span>)</span>:</span></span><br><span class="line"> ...</span><br><span class="line"> self.n_points = n_points</span><br><span class="line"><span class="comment"># 采样点的坐标偏移量,每个query在每个head,level上都需要采样n_points个点(x,y)。</span></span><br><span class="line"> self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * 
n_points * <span class="number">2</span>)</span><br><span class="line"> <span class="comment"># 每个query对应的所有采样点的注意力权重</span></span><br><span class="line"> self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)</span><br><span class="line"> <span class="comment"># 线性变换得到value</span></span><br><span class="line"> self.value_proj = nn.Linear(d_model, d_model)</span><br><span class="line"> self.output_proj = nn.Linear(d_model, d_model)</span><br><span class="line"> ...</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None)</span>:</span></span><br><span class="line"> </span><br><span class="line"> ...</span><br><span class="line"> <span class="comment"># (N,len_in, d_model=256)</span></span><br><span class="line"> value = self.value_proj(input_flatten)</span><br><span class="line"> <span class="comment"># 将原图padding的部分用0填充</span></span><br><span class="line"> <span class="keyword">if</span> input_padding_mask <span class="keyword">is</span> <span class="keyword">not</span> <span class="keyword">None</span>:</span><br><span class="line"> value = value.masked_fill(input_padding_mask[..., <span class="keyword">None</span>], float(<span class="number">0</span>))</span><br><span class="line"> <span class="comment"># 拆分成多个head,(N,Len_in,8,64)</span></span><br><span class="line"> value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)</span><br><span class="line"> <span class="comment"># 预测采样点偏移量,(N,Len_in,8,4,4,2)</span></span><br><span class="line"> sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, <span class="number">2</span>)</span><br><span class="line"> <span class="comment"># 预测采样点注意力权重,(N,Len_in,8,4*4)</span></span><br><span class="line"> attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)</span><br><span class="line"> <span class="comment"># 权重归一化,对4个特征层分别采样的4个特征点,合计16个点,进行归一化</span></span><br><span class="line"> attention_weights = F.softmax(attention_weights, <span class="number">-1</span>).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)</span><br><span class="line"> <span class="comment"># N, Len_q, n_heads, n_levels, n_points, 2</span></span><br><span class="line"> </span><br><span class="line"> <span class="keyword">if</span> reference_points.shape[<span class="number">-1</span>] == <span class="number">2</span>: </span><br><span class="line"> <span class="comment"># (4,2) 其中每个是w,h</span></span><br><span class="line"> offset_normalizer = torch.stack([input_spatial_shapes[..., <span class="number">1</span>], input_spatial_shapes[..., <span class="number">0</span>]], <span class="number">-1</span>)</span><br><span class="line"> <span class="comment"># 对坐标偏移量使用对应特征层的宽高进行归一化然后和参考点坐标相加得到采样点坐标</span></span><br><span class="line"> sampling_locations = reference_points[:, :, <span class="keyword">None</span>, :, <span class="keyword">None</span>, :] \</span><br><span class="line"> + sampling_offsets / offset_normalizer[<span class="keyword">None</span>, <span class="keyword">None</span>, <span class="keyword">None</span>, :, <span class="keyword">None</span>, :]</span><br><span class="line"> <span class="keyword">elif</span> reference_points.shape[<span class="number">-1</span>] == <span 
class="number">4</span>:</span><br><span class="line"> <span class="comment"># 最后一维度是4表示(cx,cy,w,h)</span></span><br><span class="line"> sampling_locations = reference_points[:, :, <span class="keyword">None</span>, :, <span class="keyword">None</span>, :<span class="number">2</span>] \</span><br><span class="line"> + sampling_offsets / self.n_points * reference_points[:, :, <span class="keyword">None</span>, :, <span class="keyword">None</span>, <span class="number">2</span>:] * <span class="number">0.5</span></span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="keyword">raise</span> ValueError(</span><br><span class="line"> <span class="string">'Last dim of reference_points must be 2 or 4, but get {} instead.'</span>.format(reference_points.shape[<span class="number">-1</span>]))</span><br><span class="line"> <span class="comment"># 将注意力权重与value进行计算,是调用的self.im2col_step函数</span></span><br><span class="line"> output = MSDeformAttnFunction.apply(</span><br><span class="line"> value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)</span><br><span class="line"> <span class="comment"># 做线性变换得到最终输出结果</span></span><br><span class="line"> output = self.output_proj(output)</span><br><span class="line"> <span class="keyword">return</span> output</span><br></pre></td></tr></table></figure><h4 id="Deformable-Transformer"><a href="#Deformable-Transformer" class="headerlink" title="Deformable Transformer"></a>Deformable Transformer</h4><p>与DETR大体一致,主要区别在于用Deformable Attention替换了Encoder中的self-attn和Decoder中的cross-attn。</p><h5 id="Encoder前处理"><a href="#Encoder前处理" class="headerlink" title="Encoder前处理"></a>Encoder前处理</h5><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">DeformableTransformer</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> </span><br><span class="line"> ...</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, srcs, masks, pos_embeds, 
query_embed=None)</span>:</span></span><br><span class="line"> <span class="keyword">assert</span> self.two_stage <span class="keyword">or</span> query_embed <span class="keyword">is</span> <span class="keyword">not</span> <span class="keyword">None</span></span><br><span class="line"></span><br><span class="line"> <span class="comment"># prepare input for encoder</span></span><br><span class="line"> src_flatten = []</span><br><span class="line"> mask_flatten = []</span><br><span class="line"> lvl_pos_embed_flatten = []</span><br><span class="line"> spatial_shapes = []</span><br><span class="line"> <span class="keyword">for</span> lvl, (src, mask, pos_embed) <span class="keyword">in</span> enumerate(zip(srcs, masks, pos_embeds)):</span><br><span class="line"> bs, c, h, w = src.shape</span><br><span class="line"> spatial_shape = (h, w)</span><br><span class="line"> spatial_shapes.append(spatial_shape)</span><br><span class="line"> src = src.flatten(<span class="number">2</span>).transpose(<span class="number">1</span>, <span class="number">2</span>)</span><br><span class="line"> mask = mask.flatten(<span class="number">1</span>)</span><br><span class="line"> pos_embed = pos_embed.flatten(<span class="number">2</span>).transpose(<span class="number">1</span>, <span class="number">2</span>)</span><br><span class="line"> lvl_pos_embed = pos_embed + self.level_embed[lvl].view(<span class="number">1</span>, <span class="number">1</span>, <span class="number">-1</span>)</span><br><span class="line"> lvl_pos_embed_flatten.append(lvl_pos_embed)</span><br><span class="line"> src_flatten.append(src)</span><br><span class="line"> mask_flatten.append(mask)</span><br><span class="line"> </span><br><span class="line"> <span class="comment">#多尺度特征flatten后进行拼接</span></span><br><span class="line"> src_flatten = torch.cat(src_flatten, <span class="number">1</span>)</span><br><span class="line"> <span class="comment">#多尺度mask图flatten后进行拼接</span></span><br><span class="line"> mask_flatten = torch.cat(mask_flatten, <span class="number">1</span>)</span><br><span class="line"> <span class="comment"># 多尺度位置信息flatten后进行拼接</span></span><br><span class="line"> lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, <span class="number">1</span>)</span><br><span class="line"> <span class="comment"># 记录每个特征图的尺度信息</span></span><br><span class="line"> spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)</span><br><span class="line"> <span class="comment"># 记录每个尺度特征图拼接后的起始索引</span></span><br><span class="line"> level_start_index = torch.cat((spatial_shapes.new_zeros((<span class="number">1</span>, )), spatial_shapes.prod(<span class="number">1</span>).cumsum(<span class="number">0</span>)[:<span class="number">-1</span>]))</span><br><span class="line"> <span class="comment"># 计算各个尺度特征图中非padding部分的边长占其边长的比例</span></span><br><span class="line"> valid_ratios = torch.stack([self.get_valid_ratio(m) <span class="keyword">for</span> m <span class="keyword">in</span> masks], <span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># encoder</span></span><br><span class="line"> memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)</span><br><span class="line"> ...</span><br></pre></td></tr></table></figure><p>在DeformableTransformer的前向处理过程中,首先会对多尺度相关的元素进行flatten,这些输入元素包括:多尺度特征图、各尺度特征图对应的mask(指示哪些部分属于padding)、各尺度特征图对应的位置信息(<strong>position embedding + scale-level 
embedding</strong>),另外还有些辅助信息,比如:各尺度特征图的宽高、不同尺度特征对应于被flatten的那个维度的起始索引、各尺度特征图中非padding部分的边长占其边长的比例。</p><h5 id="Encoder编码"><a href="#Encoder编码" class="headerlink" title="Encoder编码"></a>Encoder编码</h5><p>经过Encoder前处理之后的信息就会经过Encoder进行编码,输出memory。下面代码展示的是Encoder的处理过程:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">DeformableTransformerEncoder</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, encoder_layer, num_layers)</span>:</span></span><br><span class="line"> super().__init__()</span><br><span class="line"> self.layers = _get_clones(encoder_layer, num_layers)</span><br><span class="line"> self.num_layers = num_layers</span><br><span class="line"></span><br><span class="line"><span class="meta"> @staticmethod</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">get_reference_points</span><span class="params">(spatial_shapes, valid_ratios, device)</span>:</span></span><br><span class="line"> reference_points_list = []</span><br><span class="line"> <span class="keyword">for</span> lvl, (H_, W_) <span class="keyword">in</span> enumerate(spatial_shapes):</span><br><span class="line"></span><br><span class="line"> ref_y, ref_x = torch.meshgrid(torch.linspace(<span class="number">0.5</span>, H_ - <span class="number">0.5</span>, H_, dtype=torch.float32, device=device),</span><br><span class="line"> torch.linspace(<span class="number">0.5</span>, W_ - <span class="number">0.5</span>, W_, dtype=torch.float32, device=device))</span><br><span class="line"> ref_y = ref_y.reshape(<span class="number">-1</span>)[<span class="keyword">None</span>] / (valid_ratios[:, <span class="keyword">None</span>, lvl, <span class="number">1</span>] * H_)</span><br><span class="line"> ref_x = ref_x.reshape(<span class="number">-1</span>)[<span class="keyword">None</span>] / (valid_ratios[:, <span class="keyword">None</span>, lvl, <span class="number">0</span>] * W_)</span><br><span class="line"> ref = torch.stack((ref_x, ref_y), <span class="number">-1</span>)</span><br><span class="line"> reference_points_list.append(ref)</span><br><span class="line"> reference_points = torch.cat(reference_points_list, <span class="number">1</span>)</span><br><span class="line"> reference_points = reference_points[:, :, <span class="keyword">None</span>] * valid_ratios[:, 
<span class="keyword">None</span>]</span><br><span class="line"> <span class="keyword">return</span> reference_points</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None)</span>:</span></span><br><span class="line"> output = src</span><br><span class="line"> <span class="comment"># 参考点初始化,以0.5为步长,在特征图上密集采样所有点作为初始参考点</span></span><br><span class="line"> reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)</span><br><span class="line"> <span class="keyword">for</span> _, layer <span class="keyword">in</span> enumerate(self.layers):</span><br><span class="line"> output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> output</span><br></pre></td></tr></table></figure><p>输出memory(编码后的特征表示),shape是 (bs, h_lvl1<em>w_lvl1+h_lvl2</em>w_lvl2+.., c=256),其中h_lvli和w_lvli分别代表第i层特征图的高和宽,于是第二个维度就是所有特征点的数量。编码后,特征的最后一个维度hidden_dim=256.</p><h5 id="Decoder前处理"><a href="#Decoder前处理" class="headerlink" title="Decoder前处理"></a>Decoder前处理</h5><p>对encoder的输出进行处理,得到参考点reference_points,需要说明下,在2-stage模式下,参考点和输入到Decoder的object query及query embedding的生成方式和形式会有所不同。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span 
class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">DeformableTransformer</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> ...</span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">gen_encoder_output_proposals</span><span class="params">(self, memory, memory_padding_mask, spatial_shapes)</span>:</span></span><br><span class="line"> N_, S_, C_ = memory.shape</span><br><span class="line"> base_scale = <span class="number">4.0</span></span><br><span class="line"> proposals = []</span><br><span class="line"> _cur = <span class="number">0</span></span><br><span class="line"> <span class="keyword">for</span> lvl, (H_, W_) <span class="keyword">in</span> enumerate(spatial_shapes):</span><br><span class="line"> mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, <span class="number">1</span>)</span><br><span class="line"> valid_H = torch.sum(~mask_flatten_[:, :, <span class="number">0</span>, <span class="number">0</span>], <span class="number">1</span>)</span><br><span class="line"> valid_W = torch.sum(~mask_flatten_[:, <span class="number">0</span>, :, <span class="number">0</span>], <span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"> grid_y, grid_x = torch.meshgrid(torch.linspace(<span class="number">0</span>, H_ - <span class="number">1</span>, H_, dtype=torch.float32, device=memory.device),</span><br><span class="line"> torch.linspace(<span class="number">0</span>, W_ - <span class="number">1</span>, W_, dtype=torch.float32, device=memory.device))</span><br><span class="line"> grid = torch.cat([grid_x.unsqueeze(<span class="number">-1</span>), grid_y.unsqueeze(<span class="number">-1</span>)], <span class="number">-1</span>)</span><br><span class="line"></span><br><span class="line"> scale = torch.cat([valid_W.unsqueeze(<span class="number">-1</span>), valid_H.unsqueeze(<span class="number">-1</span>)], <span class="number">1</span>).view(N_, <span class="number">1</span>, <span class="number">1</span>, <span class="number">2</span>)</span><br><span class="line"> grid = (grid.unsqueeze(<span class="number">0</span>).expand(N_, <span class="number">-1</span>, <span class="number">-1</span>, <span class="number">-1</span>) + <span class="number">0.5</span>) / scale</span><br><span class="line"> wh = torch.ones_like(grid) * <span class="number">0.05</span> * (<span class="number">2.0</span> ** lvl)</span><br><span class="line"> proposal = torch.cat((grid, wh), <span class="number">-1</span>).view(N_, <span class="number">-1</span>, <span class="number">4</span>)</span><br><span class="line"> proposals.append(proposal)</span><br><span class="line"> _cur += (H_ * W_)</span><br><span class="line"> output_proposals = torch.cat(proposals, <span class="number">1</span>)</span><br><span class="line"> output_proposals_valid = ((output_proposals > <span class="number">0.01</span>) & (output_proposals < <span class="number">0.99</span>)).all(<span class="number">-1</span>, keepdim=<span 
class="keyword">True</span>)</span><br><span class="line"> output_proposals = torch.log(output_proposals / (<span class="number">1</span> - output_proposals))</span><br><span class="line"> output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(<span class="number">-1</span>), float(<span class="string">'inf'</span>))</span><br><span class="line"> output_proposals = output_proposals.masked_fill(~output_proposals_valid, float(<span class="string">'inf'</span>))</span><br><span class="line"></span><br><span class="line"> output_memory = memory</span><br><span class="line"> output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(<span class="number">-1</span>), float(<span class="number">0</span>))</span><br><span class="line"> output_memory = output_memory.masked_fill(~output_proposals_valid, float(<span class="number">0</span>))</span><br><span class="line"> output_memory = self.enc_output_norm(self.enc_output(output_memory))</span><br><span class="line"> <span class="keyword">return</span> output_memory, output_proposals</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, srcs, masks, pos_embeds, query_embed=None)</span>:</span></span><br><span class="line"> ...</span><br><span class="line"> <span class="comment"># encoder</span></span><br><span class="line"> memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># prepare input for decoder</span></span><br><span class="line"> bs, _, c = memory.shape</span><br><span class="line"> <span class="keyword">if</span> self.two_stage:</span><br><span class="line"> <span class="comment"># 生成proposal,对encoder的输出进行处理(全连接层+归一化),output_proposals对应于特征图上各个初始参考点位置(固定的)</span></span><br><span class="line"> output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)</span><br><span class="line"> </span><br><span class="line"></span><br><span class="line"> <span class="comment"># hack implementation for two-stage Deformable DETR</span></span><br><span class="line"> <span class="comment"># 借助于decoder的最后一层的class_embed和bbox_embed获取分数和proposal</span></span><br><span class="line"> enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)</span><br><span class="line"> <span class="comment"># bbox_embed预测的是相对于初始参考点位置的偏移量,所以需要加上初始参考点位置</span></span><br><span class="line"> enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals</span><br><span class="line"></span><br><span class="line"> topk = self.two_stage_num_proposals</span><br><span class="line"> <span class="comment"># 选取分数topk的proposal</span></span><br><span class="line"> topk_proposals = torch.topk(enc_outputs_class[..., <span class="number">0</span>], topk, dim=<span class="number">1</span>)[<span class="number">1</span>]</span><br><span class="line"> topk_coords_unact = torch.gather(enc_outputs_coord_unact, <span class="number">1</span>, topk_proposals.unsqueeze(<span class="number">-1</span>).repeat(<span class="number">1</span>, <span class="number">1</span>, <span class="number">4</span>))</span><br><span class="line"> topk_coords_unact = topk_coords_unact.detach()</span><br><span class="line"> reference_points = topk_coords_unact.sigmoid()</span><br><span class="line"> 
init_reference_out = reference_points</span><br><span class="line"> pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))</span><br><span class="line"> query_embed, tgt = torch.split(pos_trans_out, c, dim=<span class="number">2</span>)</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> query_embed, tgt = torch.split(query_embed, c, dim=<span class="number">1</span>)</span><br><span class="line"> query_embed = query_embed.unsqueeze(<span class="number">0</span>).expand(bs, <span class="number">-1</span>, <span class="number">-1</span>)</span><br><span class="line"> tgt = tgt.unsqueeze(<span class="number">0</span>).expand(bs, <span class="number">-1</span>, <span class="number">-1</span>)</span><br><span class="line"> reference_points = self.reference_points(query_embed).sigmoid()</span><br><span class="line"> init_reference_out = reference_points</span><br><span class="line"></span><br><span class="line"> <span class="comment"># decoder</span></span><br><span class="line"> hs, inter_references = self.decoder(tgt, reference_points, memory,</span><br><span class="line"> spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten)</span><br><span class="line"></span><br><span class="line"> inter_references_out = inter_references</span><br><span class="line"> <span class="keyword">if</span> self.two_stage:</span><br><span class="line"> <span class="keyword">return</span> hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact</span><br><span class="line"> <span class="keyword">return</span> hs, init_reference_out, inter_references_out, <span class="keyword">None</span>, <span class="keyword">None</span></span><br></pre></td></tr></table></figure><blockquote><ul><li>如果是two-stage 模式下,参考点由Encoder预测topk得分最高的proposal box(这时的参考点是4d,bbox形式),然后对参考点进行position embedding来生成Decoder需要的object query和对应的query embedding;</li><li>非two-stage模式下,Decoder的 object query(target )和 query embedding 就是预设的embedding,然后将query embedding经过全连接层输出2d参考点,这时的参考点是归一化的中心坐标形式。</li></ul></blockquote><p>另外,两种情况下生成的参考点数量可能不同:2-stage时是有top-k(作者设置为300)个,而1-stage时是<em>num_queries</em>(作者也设置为300)个,也就是和object query的数量一致(可以理解为,此时参考点就是object query本身的位置)。</p><h5 id="Decoder解码"><a href="#Decoder解码" class="headerlink" title="Decoder解码"></a>Decoder解码</h5><p>这里与Transformer中主要的区别在于使用可变形注意力替代了原生的交叉注意力。类似地,每层的解码过程是self-attention+cross-attention+ffn,下一层输入的object query是上一层输出的解码特征。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span 
class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">DeformableTransformerDecoder</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> ...</span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,</span></span></span><br><span class="line"><span class="function"><span class="params"> query_pos=None, src_padding_mask=None)</span>:</span></span><br><span class="line"> output = tgt</span><br><span class="line"></span><br><span class="line"> intermediate = []</span><br><span class="line"> intermediate_reference_points = []</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">for</span> lid, layer <span class="keyword">in</span> enumerate(self.layers):</span><br><span class="line"> <span class="keyword">if</span> reference_points.shape[<span class="number">-1</span>] == <span class="number">4</span>:</span><br><span class="line"> reference_points_input = reference_points[:, :, <span class="keyword">None</span>] \</span><br><span class="line"> * torch.cat([src_valid_ratios, src_valid_ratios], <span class="number">-1</span>)[:, <span class="keyword">None</span>]</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="keyword">assert</span> reference_points.shape[<span class="number">-1</span>] == <span class="number">2</span></span><br><span class="line"> reference_points_input = reference_points[:, :, <span class="keyword">None</span>] * src_valid_ratios[:, <span class="keyword">None</span>]</span><br><span class="line"> output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># hack implementation for iterative bounding box refinement</span></span><br><span class="line"> <span class="comment"># 使用iterative bbox refinement</span></span><br><span class="line"> <span class="keyword">if</span> self.bbox_embed <span class="keyword">is</span> <span class="keyword">not</span> <span class="keyword">None</span>:</span><br><span class="line"> <span class="comment"># 得到相对初始参考点的偏移量</span></span><br><span class="line"> tmp = self.bbox_embed[lid](output)</span><br><span class="line"> <span class="comment"># 得到归一化坐标点</span></span><br><span class="line"> <span class="keyword">if</span> reference_points.shape[<span class="number">-1</span>] == <span class="number">4</span>:</span><br><span class="line"> new_reference_points = tmp + inverse_sigmoid(reference_points)</span><br><span class="line"> new_reference_points = new_reference_points.sigmoid()</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="keyword">assert</span> reference_points.shape[<span class="number">-1</span>] == <span class="number">2</span></span><br><span class="line"> 
new_reference_points = tmp</span><br><span class="line"> new_reference_points[..., :<span class="number">2</span>] = tmp[..., :<span class="number">2</span>] + inverse_sigmoid(reference_points)</span><br><span class="line"> new_reference_points = new_reference_points.sigmoid()</span><br><span class="line"> <span class="comment"># 在输入下一层之前取消梯度,将当前层的预测bbox作为下一层的初始参考点</span></span><br><span class="line"> reference_points = new_reference_points.detach()</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> self.return_intermediate:</span><br><span class="line"> intermediate.append(output)</span><br><span class="line"> intermediate_reference_points.append(reference_points)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> self.return_intermediate:</span><br><span class="line"> <span class="keyword">return</span> torch.stack(intermediate), torch.stack(intermediate_reference_points)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> output, reference_points</span><br></pre></td></tr></table></figure><p>进行Decoder解码完之后就是接class_embed和bbox_embed得到最后box分数和坐标。在上面需要注意的一点是,<strong>每次refine后的bbox梯度是不会传递到下一层</strong>。</p><h4 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h4><ol><li>相比于DETR,使用了多尺度特征+scale-level embedding,用于区分不同特征层;</li><li>使用了多尺度可变形注意力替代Encoder中的selfattn和Decoder中的crossattn,减小计算量;</li><li>引入了参考点,类似引入先验知识;</li><li>设计了两阶段模式和iteraitive box refinement策略;</li><li>检测头回归分支预测是bbox相对参考点的偏移量而非绝对坐标值;</li></ol>]]></content>
<summary type="html">
<p>DETR是第一个end2end的目标检测器,不需要众多手工设计组件(anchor,iou匹配,nms后处理等),但也存在收敛慢,能处理的特征分辨率有限等缺陷。原因大概存在如下:</p>
<blockquote>
<ul>
<li>transformer在初始化时,分配给所有特征像素的注意力权重几乎均等;这就造成了模型需要长时间去学习关注真正有意义的位置,这些位置应该是稀疏的;</li>
<li>transformer在计算注意力权重时,伴随着高计算量与空间复杂度。特别是在编码器部分,与特征像素点的数量成平方级关系,因此难以处理高分辨率的特征;</li>
</ul>
</blockquote>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="目标检测" scheme="https://blog.nicehuster.cn/tags/%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B/"/>
</entry>
<entry>
<title>DETR源码解读</title>
<link href="https://blog.nicehuster.cn/2022/07/24/DETR%E4%BB%A3%E7%A0%81%E8%A7%A3%E8%AF%BB/"/>
<id>https://blog.nicehuster.cn/2022/07/24/DETR代码解读/</id>
<published>2022-07-24T11:13:39.000Z</published>
<updated>2022-08-26T02:16:49.016Z</updated>
<content type="html"><![CDATA[<p>transformer由encoder和decoder俩部分组成。</p><a id="more"></a><h4 id="Encoder"><a href="#Encoder" class="headerlink" title="Encoder"></a>Encoder</h4><p>一个encoder由多个encoder_layer组成,在detr中默认是6层。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">TransformerEncoder</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, encoder_layer, num_layers, norm=None)</span>:</span></span><br><span class="line"> super().__init__()</span><br><span class="line"> self.layers = _get_clones(encoder_layer, num_layers) <span class="comment"># Encoder包含num层,每层具有相同结构encoder_layer</span></span><br><span class="line"> self.num_layers = num_layers</span><br><span class="line"></span><br><span class="line"> self.norm = norm <span class="comment"># 归一化</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, src,</span></span></span><br><span class="line"><span class="function"><span class="params"> mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> src_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> pos: Optional[Tensor] = None)</span>:</span></span><br><span class="line"></span><br><span class="line"> <span class="comment"># src 对应backbone最后一层输出的feature maps,并且维度已经映射到(h*w,bs, hidden_dim)</span></span><br><span class="line"> <span class="comment"># mask 一般为空</span></span><br><span class="line"> <span class="comment"># pos 对应backbone最后一层输出的feature maps对应的位置编码,shape是(h*w,bs,c)</span></span><br><span class="line"> <span class="comment"># src_key_padding_mask 对应backbone最后一层输出的feature maps对应的mask,shape是(bs,h*w)</span></span><br><span class="line"></span><br><span class="line"> output = src</span><br><span class="line"></span><br><span class="line"> <span class="keyword">for</span> layer <span class="keyword">in</span> self.layers:</span><br><span class="line"> output = layer(output, src_mask=mask,</span><br><span class="line"> src_key_padding_mask=src_key_padding_mask, pos=pos)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> self.norm <span 
class="keyword">is</span> <span class="keyword">not</span> <span class="keyword">None</span>:</span><br><span class="line"> output = self.norm(output)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> output</span><br></pre></td></tr></table></figure><p>EncoderLayer 的前向过程分为两种情况,一种是在输入多头自注意力层和前向反馈层前先进行归一化,另一种则是在这两个层输出后再进行归一化操作。对应实现可以参考如下图左侧部分:</p><p><img src="/img/detr-trans.png" alt="detr"></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">TransformerEncoderLayer</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, d_model, nhead, dim_feedforward=<span class="number">2048</span>, dropout=<span class="number">0.1</span>,</span></span></span><br><span class="line"><span class="function"><span class="params"> activation=<span class="string">"relu"</span>, normalize_before=False)</span>:</span></span><br><span class="line"> super().__init__()</span><br><span class="line"></span><br><span class="line"> <span class="comment">#多头自注意力模块</span></span><br><span class="line"> 
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)</span><br><span class="line"> <span class="comment"># Implementation of Feedforward model</span></span><br><span class="line"> self.linear1 = nn.Linear(d_model, dim_feedforward)</span><br><span class="line"> self.dropout = nn.Dropout(dropout)</span><br><span class="line"> self.linear2 = nn.Linear(dim_feedforward, d_model)</span><br><span class="line"></span><br><span class="line"> self.norm1 = nn.LayerNorm(d_model)</span><br><span class="line"> self.norm2 = nn.LayerNorm(d_model)</span><br><span class="line"> self.dropout1 = nn.Dropout(dropout)</span><br><span class="line"> self.dropout2 = nn.Dropout(dropout)</span><br><span class="line"></span><br><span class="line"> self.activation = _get_activation_fn(activation)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 是否需要在输入多头自注意力层之前进行归一化</span></span><br><span class="line"> self.normalize_before = normalize_before</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">with_pos_embed</span><span class="params">(self, tensor, pos: Optional[Tensor])</span>:</span></span><br><span class="line"> <span class="keyword">return</span> tensor <span class="keyword">if</span> pos <span class="keyword">is</span> <span class="keyword">None</span> <span class="keyword">else</span> tensor + pos</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward_post</span><span class="params">(self,</span></span></span><br><span class="line"><span class="function"><span class="params"> src,</span></span></span><br><span class="line"><span class="function"><span class="params"> src_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> src_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> pos: Optional[Tensor] = None)</span>:</span></span><br><span class="line"> q = k = self.with_pos_embed(src, pos)</span><br><span class="line"> src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,</span><br><span class="line"> key_padding_mask=src_key_padding_mask)[<span class="number">0</span>]</span><br><span class="line"> src = src + self.dropout1(src2)</span><br><span class="line"> src = self.norm1(src)</span><br><span class="line"> src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))</span><br><span class="line"> src = src + self.dropout2(src2)</span><br><span class="line"> src = self.norm2(src)</span><br><span class="line"> <span class="keyword">return</span> src</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward_pre</span><span class="params">(self, src,</span></span></span><br><span class="line"><span class="function"><span class="params"> src_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> src_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> pos: Optional[Tensor] = None)</span>:</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 输入多头自注意力层前进行归一化</span></span><br><span class="line"> src2 = self.norm1(src)</span><br><span class="line"> <span class="comment"># 
q,k在输入attn之前需要结合位置编码</span></span><br><span class="line"> q = k = self.with_pos_embed(src2, pos)</span><br><span class="line"> <span class="comment"># self.self_attn是nn.MultiheadAttention的实例,其前向过程返回两部分,第一个是自注意力层的输出,第二个是自注意力权重,因此这里取了输出索引为0的部分即代表自注意力层的输出。</span></span><br><span class="line"> src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,</span><br><span class="line"> key_padding_mask=src_key_padding_mask)[<span class="number">0</span>]</span><br><span class="line"> </span><br><span class="line"> src = src + self.dropout1(src2)</span><br><span class="line"> src2 = self.norm2(src)</span><br><span class="line"> src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))</span><br><span class="line"> src = src + self.dropout2(src2)</span><br><span class="line"> <span class="keyword">return</span> src</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, src,</span></span></span><br><span class="line"><span class="function"><span class="params"> src_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> src_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> pos: Optional[Tensor] = None)</span>:</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 俩种不同的前向过程</span></span><br><span class="line"> <span class="keyword">if</span> self.normalize_before:</span><br><span class="line"> <span class="keyword">return</span> self.forward_pre(src, src_mask, src_key_padding_mask, pos)</span><br><span class="line"> <span class="keyword">return</span> self.forward_post(src, src_mask, src_key_padding_mask, pos)</span><br></pre></td></tr></table></figure><p>需要注意的是,在输入多头自注意力层时需要先进行位置嵌入,即结合位置编码。注意仅对query和key实施,而value不需要。query和key是在图像特征中各个位置之间计算相关性,而value作为原图像特征,使用计算出来的相关性加权上去,得到各位置结合了全局相关性(增强/削弱)后的特征表示。</p><h4 id="Query-Embedding"><a href="#Query-Embedding" class="headerlink" title="Query Embedding"></a>Query Embedding</h4><p>在解析Decoder前,有必要先简要地谈谈query embedding,因为它是Decoder的主要输入之一。query embedding 有点anchor的味道,而且是自学习的anchor,作者使用了<em>nn.Embedding</em>实现:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">self.query_embed = nn.Embedding(num_queries, hidden_dim)</span><br></pre></td></tr></table></figure><p>其中num_queries 代表图像中有多少个目标(位置),默认是100个,对这些目标(位置)全部进行嵌入,维度映射到 <em>hidden_dim</em>,将 <strong>query_embedding 的权重</strong>作为参数输入到Transformer的前向过程,使用时与position encoding的方式相同:直接相加。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[<span class="number">-1</span>])[<span class="number">0</span>]</span><br></pre></td></tr></table></figure><p>而这个query embedding应该加在哪呢?当然是我们需要预测的目标(query object)咯!可是网络一开始还没有输出,我们都不知道预测目标在哪里呀,如何将它实体化?作者也不知道,于是就简单粗暴地直接将它初始化为全0,shape和query embedding 的权重一致(从而可以element-wise add)。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span 
class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Transformer</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, d_model=<span class="number">512</span>, nhead=<span class="number">8</span>, num_encoder_layers=<span class="number">6</span>,</span></span></span><br><span class="line"><span class="function"><span class="params"> num_decoder_layers=<span class="number">6</span>, dim_feedforward=<span class="number">2048</span>, dropout=<span class="number">0.1</span>,</span></span></span><br><span class="line"><span class="function"><span class="params"> activation=<span class="string">"relu"</span>, normalize_before=False,</span></span></span><br><span class="line"><span class="function"><span class="params"> return_intermediate_dec=False)</span>:</span></span><br><span class="line"> super().__init__()</span><br><span class="line"> ...</span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, src, mask, query_embed, pos_embed)</span>:</span></span><br><span class="line"> </span><br><span class="line">...</span><br><span class="line"> <span class="comment"># (num_queries,bs,hidden_dim)</span></span><br><span class="line"> tgt = torch.zeros_like(query_embed) <span class="comment">#</span></span><br><span class="line"> memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)</span><br><span class="line"> hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,</span><br><span class="line"> pos=pos_embed, query_pos=query_embed)</span><br><span class="line"> <span class="keyword">return</span> hs.transpose(<span class="number">1</span>, <span class="number">2</span>), memory.permute(<span class="number">1</span>, <span class="number">2</span>, <span class="number">0</span>).view(bs, c, h, w)</span><br></pre></td></tr></table></figure><h4 id="Decoder"><a href="#Decoder" class="headerlink" title="Decoder"></a>Decoder</h4><p>Decoder的结构和Encoder十分类似。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span 
class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">TransformerDecoder</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, decoder_layer, num_layers, norm=None, return_intermediate=False)</span>:</span></span><br><span class="line"> super().__init__()</span><br><span class="line"> self.layers = _get_clones(decoder_layer, num_layers)</span><br><span class="line"> self.num_layers = num_layers</span><br><span class="line"> self.norm = norm</span><br><span class="line"> self.return_intermediate = return_intermediate</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, tgt, memory,</span></span></span><br><span class="line"><span class="function"><span class="params"> tgt_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> memory_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> tgt_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> memory_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> pos: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> query_pos: Optional[Tensor] = None)</span>:</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># tgt 是query embedding,shape是(num_queries,bs,hidden_dim)</span></span><br><span class="line"> <span class="comment"># query_pos 是对应tgt的位置编码,shape和tgt一致</span></span><br><span class="line"> <span class="comment"># memory是encoder的输出,shape是(h*w,bs,hidden_dim)</span></span><br><span class="line"> <span class="comment"># memory_key_padding_mask是对应encoder的src_key_padding_mask,shape是(bs,h*w)</span></span><br><span class="line"> <span class="comment"># pos 对应输入到encoder的位置编码,这里代表memory的位置编码,shape和memory一致</span></span><br><span class="line"> </span><br><span class="line"> output = tgt</span><br><span class="line"></span><br><span class="line"> intermediate = []</span><br><span class="line"></span><br><span class="line"> <span class="keyword">for</span> layer <span class="keyword">in</span> self.layers:</span><br><span class="line"> output = layer(output, memory, tgt_mask=tgt_mask,</span><br><span class="line"> memory_mask=memory_mask,</span><br><span class="line"> tgt_key_padding_mask=tgt_key_padding_mask,</span><br><span class="line"> memory_key_padding_mask=memory_key_padding_mask,</span><br><span class="line"> pos=pos, query_pos=query_pos)</span><br><span class="line"> <span class="keyword">if</span> 
self.return_intermediate:</span><br><span class="line"> intermediate.append(self.norm(output))</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> self.norm <span class="keyword">is</span> <span class="keyword">not</span> <span class="keyword">None</span>:</span><br><span class="line"> output = self.norm(output)</span><br><span class="line"> <span class="keyword">if</span> self.return_intermediate:</span><br><span class="line"> intermediate.pop()</span><br><span class="line"> intermediate.append(output)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> self.return_intermediate:</span><br><span class="line"> <span class="keyword">return</span> torch.stack(intermediate)</span><br><span class="line"> <span class="keyword">return</span> output.unsqueeze(<span class="number">0</span>)</span><br></pre></td></tr></table></figure><p>注意,在detr中,tgt_mask和memory_mask并未使用。需要注意的是intermediate中记录的是每层输出后的归一化结果,而每一层的输入是前一层输出(没有归一化)的结果。</p><p>DecoderLayer与Encoder的实现类似,只不过多了一层cross attention,其实质也是多头自注意力层,但是key和value来自于Encoder的输出。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span 
class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">TransformerDecoderLayer</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, d_model, nhead, dim_feedforward=<span class="number">2048</span>, dropout=<span class="number">0.1</span>,</span></span></span><br><span class="line"><span class="function"><span class="params"> activation=<span class="string">"relu"</span>, normalize_before=False)</span>:</span></span><br><span class="line"> super().__init__()</span><br><span class="line"> self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)</span><br><span class="line"> self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)</span><br><span class="line"> <span class="comment"># Implementation of Feedforward model</span></span><br><span class="line"> self.linear1 = nn.Linear(d_model, dim_feedforward)</span><br><span class="line"> self.dropout = nn.Dropout(dropout)</span><br><span class="line"> self.linear2 = nn.Linear(dim_feedforward, d_model)</span><br><span class="line"></span><br><span class="line"> self.norm1 = nn.LayerNorm(d_model)</span><br><span class="line"> self.norm2 = nn.LayerNorm(d_model)</span><br><span class="line"> self.norm3 = nn.LayerNorm(d_model)</span><br><span class="line"> self.dropout1 = nn.Dropout(dropout)</span><br><span class="line"> self.dropout2 = nn.Dropout(dropout)</span><br><span class="line"> self.dropout3 = nn.Dropout(dropout)</span><br><span class="line"></span><br><span class="line"> self.activation = _get_activation_fn(activation)</span><br><span class="line"> self.normalize_before = normalize_before</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">with_pos_embed</span><span class="params">(self, tensor, pos: Optional[Tensor])</span>:</span></span><br><span class="line"> <span class="keyword">return</span> tensor <span class="keyword">if</span> pos <span class="keyword">is</span> <span class="keyword">None</span> <span class="keyword">else</span> tensor + pos</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward_post</span><span class="params">(self, tgt, memory,</span></span></span><br><span class="line"><span class="function"><span class="params"> tgt_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> memory_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> tgt_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> memory_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> 
pos: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> query_pos: Optional[Tensor] = None)</span>:</span></span><br><span class="line"> q = k = self.with_pos_embed(tgt, query_pos)</span><br><span class="line"> tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,</span><br><span class="line"> key_padding_mask=tgt_key_padding_mask)[<span class="number">0</span>]</span><br><span class="line"> tgt = tgt + self.dropout1(tgt2)</span><br><span class="line"> tgt = self.norm1(tgt)</span><br><span class="line"> tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),</span><br><span class="line"> key=self.with_pos_embed(memory, pos),</span><br><span class="line"> value=memory, attn_mask=memory_mask,</span><br><span class="line"> key_padding_mask=memory_key_padding_mask)[<span class="number">0</span>]</span><br><span class="line"> tgt = tgt + self.dropout2(tgt2)</span><br><span class="line"> tgt = self.norm2(tgt)</span><br><span class="line"> tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))</span><br><span class="line"> tgt = tgt + self.dropout3(tgt2)</span><br><span class="line"> tgt = self.norm3(tgt)</span><br><span class="line"> <span class="keyword">return</span> tgt</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward_pre</span><span class="params">(self, tgt, memory,</span></span></span><br><span class="line"><span class="function"><span class="params"> tgt_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> memory_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> tgt_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> memory_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> pos: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> query_pos: Optional[Tensor] = None)</span>:</span></span><br><span class="line"> </span><br><span class="line"> tgt2 = self.norm1(tgt)</span><br><span class="line"> <span class="comment"># 进行位置嵌入</span></span><br><span class="line"> q = k = self.with_pos_embed(tgt2, query_pos)</span><br><span class="line"> <span class="comment"># 多头自注意力层,输入不包含encoder的输出</span></span><br><span class="line"> tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,</span><br><span class="line"> key_padding_mask=tgt_key_padding_mask)[<span class="number">0</span>]</span><br><span class="line"> tgt = tgt + self.dropout1(tgt2)</span><br><span class="line"> tgt2 = self.norm2(tgt)</span><br><span class="line"> <span class="comment"># cross attention,key,value来自encoder,query来自上一层输出</span></span><br><span class="line"> <span class="comment"># key,query均需进行位置嵌入</span></span><br><span class="line"> tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),</span><br><span class="line"> key=self.with_pos_embed(memory, pos),</span><br><span class="line"> value=memory, attn_mask=memory_mask,</span><br><span class="line"> key_padding_mask=memory_key_padding_mask)[<span class="number">0</span>]</span><br><span class="line"> tgt = tgt + self.dropout2(tgt2)</span><br><span class="line"> tgt2 = self.norm3(tgt)</span><br><span class="line"> tgt2 = 
self.linear2(self.dropout(self.activation(self.linear1(tgt2))))</span><br><span class="line"> tgt = tgt + self.dropout3(tgt2)</span><br><span class="line"> <span class="keyword">return</span> tgt</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, tgt, memory,</span></span></span><br><span class="line"><span class="function"><span class="params"> tgt_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> memory_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> tgt_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> memory_key_padding_mask: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> pos: Optional[Tensor] = None,</span></span></span><br><span class="line"><span class="function"><span class="params"> query_pos: Optional[Tensor] = None)</span>:</span></span><br><span class="line"> <span class="keyword">if</span> self.normalize_before:</span><br><span class="line"> <span class="keyword">return</span> self.forward_pre(tgt, memory, tgt_mask, memory_mask,</span><br><span class="line"> tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)</span><br><span class="line"> <span class="keyword">return</span> self.forward_post(tgt, memory, tgt_mask, memory_mask,</span><br><span class="line"> tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)</span><br></pre></td></tr></table></figure><p>注意,在tgt在输入到self_attn之前,需要经过position embedding,tgt+query_pos。在第二个多头注意力模块multihead_attn上,key和value均来自Encoder的输出。同样地,query和key要进行位置嵌入(而value不用)。这里cross attention计算的相关性是目标物体与图像特征各位置的相关性,然后再把这个相关性系数加权到Encoder编码后的图像特征(value)上,相当于获得了object features的意思,更好地表征了图像中的各个物体。从上面encoder和decoder的实现可以看出,作者非常强调位置嵌入的作用,每次进行attention计算前都需要进行position embedding,究其原因是因为transformer的转置不变性,即对排列和位置是不care的,然而在detection任务中却是十分重要的。</p><h4 id="Transformer"><a href="#Transformer" class="headerlink" title="Transformer"></a>Transformer</h4><p>将Encoder和Decoder封装在一起构成Transformer。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span 
class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Transformer</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, d_model=<span class="number">512</span>, nhead=<span class="number">8</span>, num_encoder_layers=<span class="number">6</span>,</span></span></span><br><span class="line"><span class="function"><span class="params"> num_decoder_layers=<span class="number">6</span>, dim_feedforward=<span class="number">2048</span>, dropout=<span class="number">0.1</span>,</span></span></span><br><span class="line"><span class="function"><span class="params"> activation=<span class="string">"relu"</span>, normalize_before=False,</span></span></span><br><span class="line"><span class="function"><span class="params"> return_intermediate_dec=False)</span>:</span></span><br><span class="line"> super().__init__()</span><br><span class="line"><span class="comment"># 构建encoder layer</span></span><br><span class="line"> encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,</span><br><span class="line"> dropout, activation, normalize_before)</span><br><span class="line"> encoder_norm = nn.LayerNorm(d_model) <span class="keyword">if</span> normalize_before <span class="keyword">else</span> <span class="keyword">None</span></span><br><span class="line"> self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)</span><br><span class="line"><span class="comment">#构建decoder layer</span></span><br><span class="line"> decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,</span><br><span class="line"> dropout, activation, normalize_before)</span><br><span class="line"> decoder_norm = nn.LayerNorm(d_model)</span><br><span class="line"> self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,</span><br><span class="line"> return_intermediate=return_intermediate_dec)</span><br><span class="line"></span><br><span class="line"> self._reset_parameters()</span><br><span class="line"></span><br><span class="line"> self.d_model = d_model <span class="comment"># 输入的embedding的特征维度</span></span><br><span class="line"> self.nhead = nhead <span class="comment">#</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">_reset_parameters</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="keyword">for</span> p <span class="keyword">in</span> self.parameters():</span><br><span class="line"> <span class="keyword">if</span> p.dim() > <span class="number">1</span>:</span><br><span class="line"> nn.init.xavier_uniform_(p)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, src, mask, query_embed, pos_embed)</span>:</span></span><br><span class="line"> <span class="comment"># 
flatten NxCxHxW to HWxNxC</span></span><br><span class="line"> bs, c, h, w = src.shape</span><br><span class="line"> <span class="comment"># 将backbone输入的feature maps进行flatten成序列,</span></span><br><span class="line"> <span class="comment"># src: (h*w,bs,c)</span></span><br><span class="line"> src = src.flatten(<span class="number">2</span>).permute(<span class="number">2</span>, <span class="number">0</span>, <span class="number">1</span>)</span><br><span class="line"> <span class="comment"># pos: (h*w,bs,hidden_dim)</span></span><br><span class="line"> pos_embed = pos_embed.flatten(<span class="number">2</span>).permute(<span class="number">2</span>, <span class="number">0</span>, <span class="number">1</span>)</span><br><span class="line"> <span class="comment"># query_embed: (num_queries, bs, hidden_dim)</span></span><br><span class="line"> query_embed = query_embed.unsqueeze(<span class="number">1</span>).repeat(<span class="number">1</span>, bs, <span class="number">1</span>)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># mask: (bs, h*w)</span></span><br><span class="line"> mask = mask.flatten(<span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"> tgt = torch.zeros_like(query_embed) <span class="comment"># 每次forward时,tgt都会初始化为0</span></span><br><span class="line"> <span class="comment"># memory: (h*w, bs, c)</span></span><br><span class="line"> memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)</span><br><span class="line"> <span class="comment"># TransformerDecoderLayer中return_intermediate设置为true,因此decoder包含了每层的输出结果,因此hs的shape是(6, num_queries,bs,hidden_dim)</span></span><br><span class="line"> hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,</span><br><span class="line"> pos=pos_embed, query_pos=query_embed)</span><br><span class="line"> <span class="keyword">return</span> hs.transpose(<span class="number">1</span>, <span class="number">2</span>), memory.permute(<span class="number">1</span>, <span class="number">2</span>, <span class="number">0</span>).view(bs, c, h, w)</span><br></pre></td></tr></table></figure><p>注意,tgt是与query embedding形状一直且设置为全0的结果,意为初始化需要预测的目标。因为一开始并不清楚这些目标,所以初始化为全0。其会在Decoder的各层不断被refine,相当于一个coarse-to-fine的过程,但是真正要学习的是query embedding,学习到的是整个数据集中目标物体的统计特征,而tgt在每次迭代训练(一个batch数据刚到来)时会被重新初始化为0。</p><h4 id="DETR"><a href="#DETR" class="headerlink" title="DETR"></a>DETR</h4><p>DETR包含backbone,encoder, decoder, prediction heads四个部分。encoder和decoder通常会用一个transformer来实现。prediction heads部分包括分类和回归。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span 
class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">DETR</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="string">""" This is the DETR module that performs object detection """</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, backbone, transformer, num_classes, num_queries, aux_loss=False)</span>:</span></span><br><span class="line"> <span class="string">""" Initializes the model.</span></span><br><span class="line"><span class="string"> Parameters:</span></span><br><span class="line"><span class="string"> backbone: torch module of the backbone to be used. See backbone.py</span></span><br><span class="line"><span class="string"> transformer: torch module of the transformer architecture. See transformer.py</span></span><br><span class="line"><span class="string"> num_classes: number of object classes</span></span><br><span class="line"><span class="string"> num_queries: number of object queries, ie detection slot. This is the maximal number of objects</span></span><br><span class="line"><span class="string"> DETR can detect in a single image. 
For COCO, we recommend 100 queries.</span></span><br><span class="line"><span class="string"> aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> super().__init__()</span><br><span class="line"> self.num_queries = num_queries</span><br><span class="line"> self.transformer = transformer</span><br><span class="line"> hidden_dim = transformer.d_model</span><br><span class="line"> <span class="comment"># class分类</span></span><br><span class="line"> self.class_embed = nn.Linear(hidden_dim, num_classes + <span class="number">1</span>)</span><br><span class="line"> <span class="comment"># box回归,包含3层nn.linear(),最后一层维度映射为4,代表bbox的中心点横、纵坐标和宽、高。</span></span><br><span class="line"> self.bbox_embed = MLP(hidden_dim, hidden_dim, <span class="number">4</span>, <span class="number">3</span>)</span><br><span class="line"> <span class="comment"># query_embed用于在Transformer中对初始化query以及对其编码生成嵌入</span></span><br><span class="line"> self.query_embed = nn.Embedding(num_queries, hidden_dim)</span><br><span class="line"><span class="comment"># input_proj是将CNN提取的特征维度映射到Transformer隐层的维度;</span></span><br><span class="line"> self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=<span class="number">1</span>)</span><br><span class="line"> self.backbone = backbone</span><br><span class="line"> self.aux_loss = aux_loss</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, samples: NestedTensor)</span>:</span></span><br><span class="line"> <span class="comment"># 将sample转换成nestedTensor类型</span></span><br><span class="line"> <span class="keyword">if</span> isinstance(samples, (list, torch.Tensor)):</span><br><span class="line"> samples = nested_tensor_from_tensor_list(samples)</span><br><span class="line"> <span class="comment"># 输入cnn提取特征,并输出pos encoding </span></span><br><span class="line"> features, pos = self.backbone(samples)</span><br><span class="line"><span class="comment"># 取出最后一层特征及对应mask</span></span><br><span class="line"> src, mask = features[<span class="number">-1</span>].decompose()</span><br><span class="line"> <span class="keyword">assert</span> mask <span class="keyword">is</span> <span class="keyword">not</span> <span class="keyword">None</span></span><br><span class="line"> hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[<span class="number">-1</span>])[<span class="number">0</span>]</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 生成分类与回归的预测结果</span></span><br><span class="line"> outputs_class = self.class_embed(hs)</span><br><span class="line"> outputs_coord = self.bbox_embed(hs).sigmoid()</span><br><span class="line"> <span class="comment"># 由于hs包含transformer中decoder每层输出,因此索引-1表示取最后一层输出</span></span><br><span class="line"> out = {<span class="string">'pred_logits'</span>: outputs_class[<span class="number">-1</span>], <span class="string">'pred_boxes'</span>: outputs_coord[<span class="number">-1</span>]}</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">if</span> self.aux_loss:</span><br><span class="line"> out[<span class="string">'aux_outputs'</span>] = self._set_aux_loss(outputs_class, outputs_coord)</span><br><span class="line"> <span class="keyword">return</span> out</span><br></pre></td></tr></table></figure><h4 
id="Postprocess"><a href="#Postprocess" class="headerlink" title="Postprocess"></a>Postprocess</h4><p>一部分DETR的输出并不是最终预测结果的形式,还需要进行简单的后处理。但是这里的后处理并不是NMS哦!DETR预测的是集合,并且在训练过程中经过匈牙利算法与GT一对一匹配学习,因此不存在重复框的情况。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">PostProcess</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="string">""" This module converts the model's output into the format expected by the coco api"""</span></span><br><span class="line"><span class="meta"> @torch.no_grad()</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, outputs, target_sizes)</span>:</span></span><br><span class="line"> </span><br><span class="line"> out_logits, out_bbox = outputs[<span class="string">'pred_logits'</span>], outputs[<span class="string">'pred_boxes'</span>]</span><br><span class="line"> <span class="keyword">assert</span> len(out_logits) == len(target_sizes)</span><br><span class="line"> <span class="keyword">assert</span> target_sizes.shape[<span class="number">1</span>] == <span class="number">2</span></span><br><span class="line"><span class="comment"># out_logits : (bs, num_queries,num_classes)</span></span><br><span class="line"> prob = F.softmax(out_logits, <span class="number">-1</span>)</span><br><span class="line"> scores, labels = prob[..., :<span class="number">-1</span>].max(<span class="number">-1</span>)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># convert to [x0, y0, x1, y1] format</span></span><br><span class="line"> boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)</span><br><span class="line"> <span class="comment"># and from relative [0, 1] to absolute [0, height] coordinates</span></span><br><span class="line"> img_h, img_w = target_sizes.unbind(<span class="number">1</span>)</span><br><span class="line"> scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=<span class="number">1</span>)</span><br><span class="line"> boxes = boxes * scale_fct[:, <span class="keyword">None</span>, :]</span><br><span class="line"></span><br><span class="line"> results = [{<span class="string">'scores'</span>: s, <span class="string">'labels'</span>: l, <span class="string">'boxes'</span>: b} <span class="keyword">for</span> s, l, b <span class="keyword">in</span> zip(scores, labels, boxes)]</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> results</span><br></pre></td></tr></table></figure><h4 id="Loss-Fuction"><a href="#Loss-Fuction" class="headerlink" title="Loss Fuction"></a>Loss 
Fuction</h4><p>这一部分主要介绍一下和损失函数相关的部分源码。先看一下与损失函数相关的代码:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line">matcher = build_matcher(args)</span><br><span class="line">weight_dict = {<span class="string">'loss_ce'</span>: <span class="number">1</span>, <span class="string">'loss_bbox'</span>: args.bbox_loss_coef}</span><br><span class="line">weight_dict[<span class="string">'loss_giou'</span>] = args.giou_loss_coef</span><br><span class="line"><span class="keyword">if</span> args.masks:</span><br><span class="line"> weight_dict[<span class="string">"loss_mask"</span>] = args.mask_loss_coef</span><br><span class="line"> weight_dict[<span class="string">"loss_dice"</span>] = args.dice_loss_coef</span><br><span class="line"><span class="comment"># TODO this is a hack</span></span><br><span class="line"><span class="keyword">if</span> args.aux_loss:</span><br><span class="line"> aux_weight_dict = {}</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> range(args.dec_layers - <span class="number">1</span>):</span><br><span class="line"> aux_weight_dict.update({k + <span class="string">f'_<span class="subst">{i}</span>'</span>: v <span class="keyword">for</span> k, v <span class="keyword">in</span> weight_dict.items()})</span><br><span class="line"> weight_dict.update(aux_weight_dict)</span><br><span class="line"></span><br><span class="line">losses = [<span class="string">'labels'</span>, <span class="string">'boxes'</span>, <span class="string">'cardinality'</span>]</span><br><span class="line"><span class="keyword">if</span> args.masks:</span><br><span class="line"> losses += [<span class="string">"masks"</span>]</span><br><span class="line">criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,</span><br><span class="line"> eos_coef=args.eos_coef, losses=losses)</span><br></pre></td></tr></table></figure><p>matcher是将预测结果与gt进行匹配的匈牙利算法,weight_dict是各部分loss设置的权重参数,包括分类与回归损失。分类使用的是CE loss,回归包括l1 loss和giou loss。如果包含分割任务,还有mask相关损失函数,另外如果设置了aux_loss,则代表计算decoder中间层预测结果对应的loss。 loss函数的实例化使用SetCriterion进行构建的。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span 
class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">SetCriterion</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="string">""" This class computes the loss for DETR.</span></span><br><span class="line"><span class="string"> The process happens in two steps:</span></span><br><span class="line"><span class="string"> 1) we compute hungarian assignment between ground truth boxes and the outputs of the model</span></span><br><span class="line"><span class="string"> 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, num_classes, matcher, weight_dict, eos_coef, losses)</span>:</span></span><br><span class="line"> <span class="string">""" Create the criterion.</span></span><br><span class="line"><span class="string"> Parameters:</span></span><br><span class="line"><span class="string"> num_classes: number of object categories, omitting the special no-object category</span></span><br><span class="line"><span class="string"> matcher: module able to compute a matching between targets and proposals</span></span><br><span class="line"><span class="string"> weight_dict: dict containing as key the names of the losses and as values their relative weight.</span></span><br><span class="line"><span class="string"> eos_coef: relative classification weight applied to the no-object category</span></span><br><span class="line"><span class="string"> losses: list of all the losses to be applied. 
See get_loss for list of available losses.</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> super().__init__()</span><br><span class="line"> self.num_classes = num_classes</span><br><span class="line"> self.matcher = matcher</span><br><span class="line"> self.weight_dict = weight_dict</span><br><span class="line"> <span class="comment"># 针对背景分类的loss权重</span></span><br><span class="line"> self.eos_coef = eos_coef</span><br><span class="line"> self.losses = losses</span><br><span class="line"> empty_weight = torch.ones(self.num_classes + <span class="number">1</span>)</span><br><span class="line"> empty_weight[<span class="number">-1</span>] = self.eos_coef</span><br><span class="line"> self.register_buffer(<span class="string">'empty_weight'</span>, empty_weight)</span><br><span class="line"></span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> '''</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">get_loss</span><span class="params">(self, loss, outputs, targets, indices, num_boxes, **kwargs)</span>:</span></span><br><span class="line"> loss_map = {</span><br><span class="line"> <span class="string">'labels'</span>: self.loss_labels,</span><br><span class="line"> <span class="string">'cardinality'</span>: self.loss_cardinality,</span><br><span class="line"> <span class="string">'boxes'</span>: self.loss_boxes,</span><br><span class="line"> <span class="string">'masks'</span>: self.loss_masks</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">assert</span> loss <span class="keyword">in</span> loss_map, <span class="string">f'do you really want to compute <span class="subst">{loss}</span> loss?'</span></span><br><span class="line"> <span class="keyword">return</span> loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, outputs, targets)</span>:</span></span><br><span class="line"> <span class="string">""" This performs the loss computation.</span></span><br><span class="line"><span class="string"> Parameters:</span></span><br><span class="line"><span class="string"> outputs: dict of tensors, see the output specification of the model for the format</span></span><br><span class="line"><span class="string"> targets: list of dicts, such that len(targets) == batch_size.</span></span><br><span class="line"><span class="string"> The expected keys in each dict depends on the losses applied, see each loss' doc</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> </span><br><span class="line"> outputs_without_aux = {k: v <span class="keyword">for</span> k, v <span class="keyword">in</span> outputs.items() <span class="keyword">if</span> k != <span class="string">'aux_outputs'</span>}</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Retrieve the matching between the outputs of the last layer and the targets</span></span><br><span class="line"> <span class="comment"># 将预测结果与GT进行匹配,indices是一个与bs长度相等的多元组的list</span></span><br><span class="line"> <span class="comment"># 每个元组为(ind_i,ind_j),前者是匹配的预测预测索引,后者是gt的索引</span></span><br><span class="line"> indices = 
self.matcher(outputs_without_aux, targets)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Compute the average number of target boxes accross all nodes, for normalization purposes</span></span><br><span class="line"> num_boxes = sum(len(t[<span class="string">"labels"</span>]) <span class="keyword">for</span> t <span class="keyword">in</span> targets)</span><br><span class="line"> num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)</span><br><span class="line"> <span class="keyword">if</span> is_dist_avail_and_initialized():</span><br><span class="line"> torch.distributed.all_reduce(num_boxes)</span><br><span class="line"> num_boxes = torch.clamp(num_boxes / get_world_size(), min=<span class="number">1</span>).item()</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Compute all the requested losses</span></span><br><span class="line"> <span class="comment"># 计算所有相关的损失,其中self.losses = ['labels', 'boxes', 'cardinality']</span></span><br><span class="line"> losses = {}</span><br><span class="line"> <span class="keyword">for</span> loss <span class="keyword">in</span> self.losses:</span><br><span class="line"> losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))</span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> '''</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> <span class="keyword">return</span> losses</span><br></pre></td></tr></table></figure><p>从forward函数可以看出,首先进行匈牙利匹配的是decoder最后一层的输出,之后再计算匹配后的损失函数包括losses = [‘labels’, ‘boxes’, ‘cardinality’],具体计算部分可以看get_loss方法中映射的对应计算方法,其中包括self.loss_labels,self.loss_cardinality,self.loss_boxes。</p><h4 id="匈牙利匹配"><a href="#匈牙利匹配" class="headerlink" title="匈牙利匹配"></a>匈牙利匹配</h4><p>匈牙利算法,在这里用于预测集(prediction set)和GT的匹配,<strong>最终匹配方案是选取“loss总和”最小的分配方式。</strong>注意,这里计算的loss与损失函数中计算loss并不相同,在这里是用来作为代价cost,cost大小决定匹配程度。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span 
class="line">46</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">HungarianMatcher</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, cost_class: float = <span class="number">1</span>, cost_bbox: float = <span class="number">1</span>, cost_giou: float = <span class="number">1</span>)</span>:</span></span><br><span class="line"> <span class="string">"""Creates the matcher</span></span><br><span class="line"><span class="string"></span></span><br><span class="line"><span class="string"> Params:</span></span><br><span class="line"><span class="string"> cost_class: This is the relative weight of the classification error in the matching cost</span></span><br><span class="line"><span class="string"> cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost</span></span><br><span class="line"><span class="string"> cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> super().__init__()</span><br><span class="line"> self.cost_class = cost_class</span><br><span class="line"> self.cost_bbox = cost_bbox</span><br><span class="line"> self.cost_giou = cost_giou</span><br><span class="line"> <span class="keyword">assert</span> cost_class != <span class="number">0</span> <span class="keyword">or</span> cost_bbox != <span class="number">0</span> <span class="keyword">or</span> cost_giou != <span class="number">0</span>, <span class="string">"all costs cant be 0"</span></span><br><span class="line"></span><br><span class="line"><span class="meta"> @torch.no_grad()</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, outputs, targets)</span>:</span></span><br><span class="line"> bs, num_queries = outputs[<span class="string">"pred_logits"</span>].shape[:<span class="number">2</span>]</span><br><span class="line"></span><br><span class="line"> <span class="comment"># We flatten to compute the cost matrices in a batch</span></span><br><span class="line"> out_prob = outputs[<span class="string">"pred_logits"</span>].flatten(<span class="number">0</span>, <span class="number">1</span>).softmax(<span class="number">-1</span>) <span class="comment"># [batch_size * num_queries, num_classes]</span></span><br><span class="line"> out_bbox = outputs[<span class="string">"pred_boxes"</span>].flatten(<span class="number">0</span>, <span class="number">1</span>) <span class="comment"># [batch_size * num_queries, 4]</span></span><br><span class="line"></span><br><span class="line"> <span class="comment"># Also concat the target labels and boxes</span></span><br><span class="line"> tgt_ids = torch.cat([v[<span class="string">"labels"</span>] <span class="keyword">for</span> v <span class="keyword">in</span> targets])</span><br><span class="line"> tgt_bbox = torch.cat([v[<span class="string">"boxes"</span>] <span class="keyword">for</span> v <span class="keyword">in</span> targets])</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Compute the classification cost. 
Contrary to the loss, we don't use the NLL,</span></span><br><span class="line"> <span class="comment"># but approximate it in 1 - proba[target class].</span></span><br><span class="line"> <span class="comment"># The 1 is a constant that doesn't change the matching, it can be ommitted.</span></span><br><span class="line"> cost_class = -out_prob[:, tgt_ids]</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Compute the L1 cost between boxes</span></span><br><span class="line"> cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=<span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Compute the giou cost betwen boxes</span></span><br><span class="line"> cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))</span><br><span class="line"></span><br><span class="line"> <span class="comment"># Final cost matrix</span></span><br><span class="line"> C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou</span><br><span class="line"> C = C.view(bs, num_queries, <span class="number">-1</span>).cpu()</span><br><span class="line"></span><br><span class="line"> sizes = [len(v[<span class="string">"boxes"</span>]) <span class="keyword">for</span> v <span class="keyword">in</span> targets]</span><br><span class="line"> indices = [linear_sum_assignment(c[i]) <span class="keyword">for</span> i, c <span class="keyword">in</span> enumerate(C.split(sizes, <span class="number">-1</span>))]</span><br><span class="line"> <span class="keyword">return</span> [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) <span class="keyword">for</span> i, j <span class="keyword">in</span> indices]</span><br></pre></td></tr></table></figure><p>从上面可以看到,匈牙利匹配在前向计算过程中,是不需要梯度的。其中分类cost是直接采用1减去预测概率的形式,同时由于1是常数,于是作者甚至连1都省去了,在box上计算了l1和giou两种cost,之后对各部分进行加权求和得到总的cost。匹配方法使用的是 <em>scipy</em> 优化模块中的 <em>linear_sum_assignment()</em>,其输入是二分图的度量矩阵,该方法是计算这个二分图度量矩阵的最小权重分配方式,返回的是匹配方案对应的矩阵行索引和列索引。</p><h4 id="End"><a href="#End" class="headerlink" title="End"></a>End</h4><p>至此,DETR所有相关源码均已解读完毕。</p>]]></content>
<summary type="html">
<p>transformer由encoder和decoder两部分组成。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="transformer" scheme="https://blog.nicehuster.cn/tags/transformer/"/>
</entry>
<entry>
<title>pix2seq方法详解</title>
<link href="https://blog.nicehuster.cn/2022/07/18/pix2seq/"/>
<id>https://blog.nicehuster.cn/2022/07/18/pix2seq/</id>
<published>2022-07-18T11:13:39.000Z</published>
<updated>2022-07-19T02:46:10.103Z</updated>
<content type="html"><![CDATA[<p>本文分享seq2seq learning相关的两篇论文,单位是google brain,一作均为Ting Chen(自监督学习方法SimCLR的作者),论文地址:<a href="https://arxiv.org/pdf/2109.10852.pdf" target="_blank" rel="noopener">pix2seq: A Language Modeling Framework for Object Detection</a>,[ICLR2022接收];<a href="https://arxiv.org/pdf/2206.07669.pdf" target="_blank" rel="noopener">A Unified Sequence Interface for Vision Tasks</a>,[上星期挂arxiv],后者是对前者在多个视觉任务上的拓展。下面大致的介绍一下两篇论文的具体工作。</p><a id="more"></a><p><img src="/img/image-20220718111615345.png" alt="image-20220718111615345"></p><h3 id="Pix2seq"><a href="#Pix2seq" class="headerlink" title="Pix2seq"></a>Pix2seq</h3><h4 id="主要方法"><a href="#主要方法" class="headerlink" title="主要方法"></a>主要方法</h4><p>pix2seq将目标检测任务转换为语言建模来处理,以往目标检测任务基于属性进行预测,通常会分为分类预测+回归预测。而该方法是通过将object的类别和box信息构造成序列,对序列进行预测。pix2seq的结构和学习过程由四个部分组成,如下图所示:</p><p><img src="/img/image-20220718112047122.png" alt="image-20220718112047122"></p><p>Pix2seq主要包含四部分:</p><blockquote><ol><li>Image augmentation:图像数据增强,包括random scale+crops;</li><li>Sequence construction & augmentation:构造序列和序列增强,将bbox和label转换为离散的token;</li><li>Architecture:使用encoder-decoder结构将输入图像pixel转换为序列;</li><li>Objective/loss function:常用的softmax 交叉熵损失;</li></ol></blockquote><h4 id="序列化"><a href="#序列化" class="headerlink" title="序列化"></a>序列化</h4><p><img src="/img/image-20220718112842096.png" alt="image-20220718112842096"></p><p>这里详细介绍一下box+cls序列化方法,为了和自然语言对齐,它把坐标框(4个值)和类别(1)都拼成一个序列,意味着100个目标对应着长度为500的序列。因为坐标是连续值,作者这里用了一个分桶的机制,把坐标分到n个桶里(bin),就构成了离散值。具体地,一个目标被表示为一个由五个离散的[token]组成的序列,即[ymin, xmin, ymax, xmax, c],其中每个连续的角坐标被均匀地离散为[1, nbins]之间的一个整数,c是类索引。我们对所有标记使用共享词汇表,因此词汇量大小等于 bin 数+类别数。对于600x600的图片而言,使用600个bin就可以实现零量化误差,其实整个离散值的范围比起nlp里的字典而言,还是非常非常小的。</p><h4 id="实验结果"><a href="#实验结果" class="headerlink" title="实验结果"></a>实验结果</h4><p>在构建好序列之后,使用Resnet + 6层transformer encoder + 6层transformer decoder对输入图像进行序列化,然后使用交叉熵计算损失。作者分别在train from scratch和finetune两种setting下进行了一些实验对比。</p><p><img src="/img/image-20220718211018875.png" alt="image-20220718211018875"></p><p>从上面结果来看,相比较而言,在指标上优势并不明显,但足矣证明本文的idea是可行的。在train from scratch的setting下,pix2seq是训练了300epoch,表格上之所以并没表明对比方法训练的epoch数,可能这正是pix2seq的一个缺点,训练收敛慢。</p><h3 id="PixSeq-v2"><a href="#PixSeq-v2" class="headerlink" title="PixSeq v2"></a>PixSeq v2</h3><h4 id="主要方法-1"><a href="#主要方法-1" class="headerlink" title="主要方法"></a>主要方法</h4><p>Pixseq v2是上周Ting Chen挂在arxiv的对pix2seq在多个视觉任务上拓展的一个工作,总的来说,作者并没有对模型层面做进一步改进,但对不同视觉task的输入输出接口做了统一。如下图所示,</p><p><img src="/img/image-20220719094051172.png" alt="image-20220719094051172"></p><p>以往的视觉任务比如,目标检测、实例分割、关键点检测和图像描述等任务都是单独设计不同模型、不同输入、不同损失函数来解决,而本文将每个任务的输出形式化为具有一个统一接口的一个离散的token序列,可以做到在所有这些任务上仅训练一个具有单一模型结构和损失函数的神经网络,而不需要针对特定任务进行模型结构或损失函数的定制。为了解决一个特定的任务,本文使用一个简短的prompt作为该任务的描述,网络的输出序列适应于该prompt,因此模型能够产生特定于任务的输出。</p><blockquote><ol><li>对于目标检测任务,遵循pix2seq做法,通过量化连续图像坐标,将box和cls转换为一系列离散token;</li><li>对于实例分割任务,以图像坐标序列形式预测polygon,与检测任务一样,对坐标进行量化离散为token;</li><li>对于关键点预测任务,给定一个人体实例,将关键点预测为一个量化的图像坐标序列;</li><li>对于图像描述,直接预测文本token。</li></ol></blockquote><p><img src="/img/image-20220719094807354.png" alt="image-20220719094807354"></p><p>值得注意的是,所有四个任务都使用同一个词汇表。 具体的prompt和输出序列如上图所示</p><h4 id="训练"><a href="#训练" class="headerlink" title="训练"></a>训练</h4><p>每个任务都有自己的成对图像序列训练数据。有两种方法可以将任务结合起来进行联合训练。作者提出了data mixing和batch mixing两种数据混合方式。</p><p><img src="/img/image-20220719095713926.png" alt="image-20220719095713926"></p><p>data mixing在概念上简单,但是因为数据格式不同,图像增强很难合并比较麻烦,相比较而言,batch mixing对单个任务采样图像后进行相应增强后转换为图像-序列对,模型分别计算每个任务的损失和梯度。作者认为可以将特定任务的每一批数据的梯度以适当的形式加权组合起来。</p><p><img src="/img/image-20220719100348972.png" 
alt="image-20220719100348972"></p><p>在损失函数上,与pix2seq一样,训练目标是最大化基于图像的token和之前的token的似然性。</p><script type="math/tex; mode=display">\operatorname{maximize} \sum_{j=1}^{L} \boldsymbol{w}_{j} \log P\left(\boldsymbol{y}_{j} \mid \boldsymbol{x}, \boldsymbol{y}_{1: j-1}\right)</script><p>其中,x表示输入图像,y是长度为L的编码序列(监督信号),序列y的初始部分是一个prompt,为此作者将权重wi设置为零,损失计算时不包括该部分。</p><h4 id="实验结果-1"><a href="#实验结果-1" class="headerlink" title="实验结果"></a>实验结果</h4><p><img src="/img/image-20220719103500099.png" alt="image-20220719103500099"></p><p>从上述表格可以看出,在模型结构和损失函数都没有针对特定任务进行设计的前提下,本文所提出的模型对于每个单独的任务仍然可以获得与专门定制化的baseline相比,依然具有一定的可比性(即使输入图像的尺寸更小)。</p>]]></content>
<summary type="html">
<p>本文分享seq2seq learning相关的两篇论文,单位是google brain,一作均为Ting Chen(自监督学习方法SimCLR的作者),论文地址:<a href="https://arxiv.org/pdf/2109.10852.pdf" target="_blank" rel="noopener">pix2seq: A Language Modeling Framework for Object Detection</a>,[ICLR2022接收];<a href="https://arxiv.org/pdf/2206.07669.pdf" target="_blank" rel="noopener">A Unified Sequence Interface for Vision Tasks</a>,[上星期挂arxiv],后者是对前者在多个视觉任务上的拓展。下面大致的介绍一下两篇论文的具体工作。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="seq2seq" scheme="https://blog.nicehuster.cn/tags/seq2seq/"/>
</entry>
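<p>针对前面 Pix2seq / Pix2seq v2 的训练目标,再补充一个 PyTorch 示意(非官方实现,张量形状、词表大小等均为假设):按 token 加权的交叉熵,把 prompt 部分的权重置零、不参与损失:</p><pre><code class="python">import torch
import torch.nn.functional as F

def weighted_seq_loss(logits, targets, weights):
    """logits: [B, L, V];targets/weights: [B, L]。
    对应 maximize sum_j w_j * log P(y_j | x, y_{1:j-1}) 的负对数似然形式。"""
    vocab = logits.size(-1)
    nll = F.cross_entropy(
        logits.reshape(-1, vocab), targets.reshape(-1), reduction="none"
    ).reshape_as(weights)
    return (weights * nll).sum() / weights.sum().clamp(min=1.0)

# 构造权重:prompt 部分(前 prompt_len 个 token)权重设为 0,不参与损失
B, L, V, prompt_len = 2, 12, 3000, 4
logits  = torch.randn(B, L, V)
targets = torch.randint(0, V, (B, L))
weights = torch.ones(B, L)
weights[:, :prompt_len] = 0.0
print(weighted_seq_loss(logits, targets, weights))
</code></pre>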
<entry>
<title>ssh免密登陆(精简命令行登陆)</title>
<link href="https://blog.nicehuster.cn/2022/06/28/ssh%E5%85%8D%E5%AF%86%E7%99%BB%E9%99%86/"/>
<id>https://blog.nicehuster.cn/2022/06/28/ssh免密登陆/</id>
<published>2022-06-28T11:13:39.000Z</published>
<updated>2022-09-09T08:53:19.128Z</updated>
<content type="html"><![CDATA[<p>这篇文章简要记录一下免密登陆服务器的具体设置过程。</p><a id="more"></a><p>每次ssh登陆服务器都需要输入一串字符,还要输入密码比较繁琐。如下:</p><figure class="highlight css"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="selector-tag">ssh</span> <span class="selector-tag">username</span>@<span class="keyword">192</span>.<span class="keyword">168</span>.<span class="keyword">1</span>.<span class="keyword">100</span></span><br></pre></td></tr></table></figure><p>常用的登录命令形式,之后还需要输入密码验证。麻烦。如何才能简化呢。方法如下:</p><p>第一步:简化登陆命令行,效果如下:</p><figure class="highlight lsl"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">ssh <span class="number">100</span> <=等效于=> ssh username@<span class="number">192.168</span><span class="number">.1</span><span class="number">.100</span></span><br></pre></td></tr></table></figure><p>方法如下:修改~/.ssh/config (如果没有.ssh或者config,就新建一个)</p><figure class="highlight routeros"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">Host 100</span><br><span class="line">HostName 192.168.1.100</span><br><span class="line">Port 22</span><br><span class="line">User username</span><br></pre></td></tr></table></figure><p>保存后,输入:ssh 100 就可以等了服务器了,但是还是需要输入密码。</p><p>第二步:实现免密码登录</p><p>ssh常用公钥和私钥的方式实现免密码登录,在你安装ssh后,自带了一个ssh-genkey的工具生成公钥和私钥。设置方法如下:</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">test</span>@ubuntu:~/.ssh$ ls</span><br><span class="line">config</span><br><span class="line"><span class="built_in">test</span>@ubuntu:~/.ssh$</span><br><span class="line"><span class="built_in">test</span>@ubuntu:~/.ssh$ ssh-keygen</span><br><span class="line">Generating public/private rsa key pair.</span><br><span class="line">Enter file <span class="keyword">in</span> <span class="built_in">which</span> to save the key (/home/yaolan/.ssh/id_rsa): id_rsa (输入保存的文件名称)</span><br><span class="line">Enter passphrase (empty <span class="keyword">for</span> no passphrase): (输入Enter键)</span><br><span class="line">Enter same passphrase again: (输入Enter键)</span><br><span class="line">Your identification has been saved <span class="keyword">in</span> id_rsa.</span><br><span class="line">Your public key has been saved <span class="keyword">in</span> id_rsa.pub.</span><br><span class="line">The key fingerprint is:</span><br><span class="line">14:b5:e4:73:1a:c7:95:d1:f4:86:3e:0c:6d:6e:cc:ef yaolan@VirtualBox</span><br><span 
class="line">The key<span class="string">'s randomart image is:</span></span><br><span class="line"><span class="string">+--[ RSA 2048]----+</span></span><br><span class="line"><span class="string">| ..o o=.|</span></span><br><span class="line"><span class="string">| + o o..o|</span></span><br><span class="line"><span class="string">| . = = + o|</span></span><br><span class="line"><span class="string">| . * O . |</span></span><br><span class="line"><span class="string">| S . O |</span></span><br><span class="line"><span class="string">| . o |</span></span><br><span class="line"><span class="string">| .|</span></span><br><span class="line"><span class="string">| . |</span></span><br><span class="line"><span class="string">| E|</span></span><br><span class="line"><span class="string">+-----------------+</span></span><br><span class="line"><span class="string">test@ubuntu:~/.ssh$ ls</span></span><br><span class="line"><span class="string">config id_rsa id_rsa.pub</span></span><br></pre></td></tr></table></figure><p>id_rsa私钥,id_rsa.pub公钥,采用RSA加密形式。我们只要把 id_rsa.pub里面的公钥添加到服务器上的~/.ssh/authorized_keys文件中即可。</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">cat ./id_rsa.pub >> ~/.ssh/authorized_keys</span><br></pre></td></tr></table></figure><p>完成这步,我们就可以免密码等了</p>]]></content>
<summary type="html">
<p>这篇文章简要记录一下免密登陆服务器的具体设置过程。</p>
</summary>
<category term="project" scheme="https://blog.nicehuster.cn/categories/project/"/>
<category term="linux" scheme="https://blog.nicehuster.cn/tags/linux/"/>
</entry>
<entry>
<title>ViLD基于CLIP模型的zero-shot目标检测方法</title>
<link href="https://blog.nicehuster.cn/2022/06/13/ViLD/"/>
<id>https://blog.nicehuster.cn/2022/06/13/ViLD/</id>
<published>2022-06-13T11:13:39.000Z</published>
<updated>2022-06-15T03:23:56.100Z</updated>
<content type="html"><![CDATA[<p><strong>论文信息</strong>:<a href="https://arxiv.org/pdf/2104.13921.pdf" target="_blank" rel="noopener">Open-vocabulary Object Detection via Vision and Language Knowledge Distillation</a><br><strong>代码链接</strong>:<a href="https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/vild" target="_blank" rel="noopener">https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/vild</a><br><strong>整体信息</strong>:这是google research 发表在ICLR2022上有关CLIP在下游任务-目标检测任务上的应用。使用CLIP模型实现zero-shot场景下的目标检测任务。比较有想象意义的是,通过一句话就可以检测出图像中需要的指定目标。在之前<a href="https://nicehuster.github.io/2022/06/03/CLIP/#more" target="_blank" rel="noopener">CLIP图文多模态对比预训练方法详解</a>中也有提及过这篇工作。</p><a id="more"></a><p><img src="https://camo.githubusercontent.com/238affd721b73ef2ff8f95114352188693623796bcc57b83f94747a7d2b6ff3a/68747470733a2f2f73746f726167652e676f6f676c65617069732e636f6d2f636c6f75642d7470752d636865636b706f696e74732f646574656374696f6e2f70726f6a656374732f76696c642f6173736574732f6e65775f7465617365722e706e67" alt></p><p>上图展示的是ViLD的检测结果。其中只有toy类别是训练过程中见到的类别,但zero-shot detection还检测到其他的属性,比如toy种类和颜色等。</p><h3 id="zero-shot-detection"><a href="#zero-shot-detection" class="headerlink" title="zero-shot detection"></a>zero-shot detection</h3><p>顾名思义,这个任务就是,对于任意一个新类别,一张训练图像都不给的情况下,训练出来的检测器也能检测这个类别。zero-shot detection的setting通常是,将数据集依据类别划分为俩部分:base类别和novel类别,base类别用于训练,novel类别在训练过程中不可见。该任务的目标在于,在novel类别上获得较好性能的同时还需要保持base类别的性能;在这篇文章中使用的是LVIS数据集进行实验对比分析,将common和frequency俩类别作为base类,将rare类别作为novel类。</p><h3 id="常规方法"><a href="#常规方法" class="headerlink" title="常规方法"></a>常规方法</h3><p><img src="/img/vild-Vanilla.png" alt="vild-Vanilla"></p><p>如上图展示的是zero-shot detection with cropped regions,具体地,使用在二阶段检测方法比如Mask-RCNN获得proposal之后,对每个proposal都crop & resize 然后输入到CLIP-image-encoder中获得image-embedding,与对应类别的text-embedding进行对比,获取类别信息。该方法的的缺点是比较慢,需要one-by-one地处理每个object proposal,而且CLIP-text-encoder没有充分利用base类别的文本信息。</p><h3 id="ViLD方法"><a href="#ViLD方法" class="headerlink" title="ViLD方法"></a>ViLD方法</h3><p><img src="/img/vild-pipeline.png" alt="vild-pipeline"></p><p>上图展示的是ViLD方法的pipeline。具体地,在ViLD中包含俩部分:<strong>ViLD-text</strong>用于学习文本embedding和<strong>ViLD-image</strong>用于学习图像embedding。在ViLD-text中,将base类别文本送入CLIP-text-encoder中获得text embedding,然后用于classify目标区域,在ViLD-image中会将对应的proposal送入CLIP-image-encoder中获得图像embedding,对经过roi align之后的region embedding 进行知识蒸馏;相比于ViLD-text,ViLD-image蒸馏了base+novel的信息,因为proposal网络输出的proposal可能会包含novel,而ViLD-text只使用了base类的文本信息;</p><p><img src="/img/vild-overview.png" alt="vild-overview"></p><p>上图展示的是ViLD的训练和推理流程。相比于mask-rcnn,修改地是rcnn的分类分支;具体地,在训练过程中,在获取分类监督信号上包括俩部分:用CLIP获得image embedding蒸馏region embedding,以及用CLIP获得text embedding监督region embedding;总的损失如下公式所示:</p><script type="math/tex; mode=display">\mathcal{L}_{\mathrm{ViLD}}=\mathcal{L}_{\text {ViLD-text }}+w \cdot \mathcal{L}_{\mathrm{ViLD}-\mathrm{image}}</script><p>在推理过程,只需要将region embedding和text embedding(base+novel)进行对比即可得到类别信息。</p><h3 id="实验"><a href="#实验" class="headerlink" title="实验"></a>实验</h3><p><strong>数据集</strong>:实验使用的是LVIS v1.0(1203类别),其中frequent(f: 405个类别)和common(c: 461个类别)作为base类,其余rare(r: 337个类别)作为novel类。</p><p><strong>目标proposal</strong>:由于训练过程中只使用了base类训练,下表展示的是仅使用base类训练时的RPN召回率和使用base+novel时的RPN召回率,从上可以看出二者相差1-2个点。因此可以看出RPN是具备从base类上泛化到novel。</p><p><img src="/img/vild-recall.png" alt="vild-recall"></p><p><strong>Ablation</strong>:作者在paper中做了较为详尽的ablation study实验,这里只提及一些证明idea有效的关键实验分析。</p><p><img src="/img/vild-ensemble.png" alt="vild-ensemble"></p><p>上表格中,CLIP on cropped 
regions就是前面介绍的常规方法,该方法在APr上可以达到13.0,ViLD-text和ViLD-image表示分别使用单一监督信号。ViLD(w=0.5)表示同时使用ViLD-text和ViLD-image监督训练。ViLD-text相比CLIP on cropped regions在APr上下降了3个点,说明仅使用base类文本信息监督时,ViLD-text在novel类上的泛化性有所下降。ViLD(w=0.5)相比于ViLD-text和ViLD-image都提升幅度明显。ViLD-ensemble(w=0.5)表示在同时使用ViLD-text和ViLD-image监督训练的基础上,推理时对两路得分做加权集成:base类的预测更倚重ViLD-text的得分,novel类的预测则反过来更倚重ViLD-image的得分。可以看出ViLD-ensemble(w=0.5)方式在base类别上提升明显。</p><p><strong>Transfer to other detection datasets</strong>:这组实验验证了ViLD在不同检测数据集之间迁移的有效性:只需要替换类别对应的text embedding,无需进行fine-tune。</p><p><img src="/img/vild-transfer.png" alt="vild-transfer"></p><p>在不进行任何fine-tune的情况下,ViLD在COCO数据集上就可以取得36.6 AP,与fine-tune条件下的AP只相差不到3个点。</p><h3 id="最后"><a href="#最后" class="headerlink" title="最后"></a>最后</h3><p>作者也在离线交互式检测上做过一些实验:输入文本信息,就可以检测出对应的目标。这个还挺有意思的,随意说一句话就能检测到图像中的指定目标。</p><p><img src="/img/vild-interactive detection.png" alt="vild-interactive detection"></p>]]></content>
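<p>为了便于理解 ViLD 的两路监督,下面给出训练损失的一个极简 PyTorch 示意(非官方实现,变量名、张量形状、归一化方式和温度系数等细节均为简化假设):region embedding 一方面与 base 类的 text embedding 做相似度分类(ViLD-text),另一方面向 CLIP image embedding 做 L1 蒸馏(ViLD-image),两项按 w 加权求和:</p><pre><code class="python">import torch
import torch.nn.functional as F

def vild_loss(region_emb, distill_emb, clip_image_emb, text_emb, bg_emb, labels, w=0.5, tau=0.01):
    """region_emb / distill_emb: [R, D],分别是分类分支与蒸馏分支的 region embedding;
    clip_image_emb: [R, D],由 CLIP image encoder 对 cropped proposal 提取的蒸馏目标;
    text_emb: [C_base, D],base 类的 CLIP text embedding;bg_emb: [1, D],可学习的背景 embedding;
    labels: [R],0 表示背景,1..C_base 表示 base 类。"""
    # ViLD-text:与(背景 + base 类文本)embedding 做余弦相似度,按温度 tau 缩放后做交叉熵
    cls_emb = torch.cat([bg_emb, text_emb], dim=0)
    logits = F.normalize(region_emb, dim=-1) @ F.normalize(cls_emb, dim=-1).t() / tau
    loss_text = F.cross_entropy(logits, labels)
    # ViLD-image:用 L1 距离把蒸馏分支的 region embedding 拉向 CLIP image embedding
    loss_image = F.l1_loss(F.normalize(distill_emb, dim=-1), F.normalize(clip_image_emb, dim=-1))
    return loss_text + w * loss_image

R, D, C = 8, 512, 100
loss = vild_loss(torch.randn(R, D), torch.randn(R, D), torch.randn(R, D),
                 torch.randn(C, D), torch.randn(1, D), torch.randint(0, C + 1, (R,)))
print(loss)
</code></pre>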
<summary type="html">
<p><strong>论文信息</strong>:<a href="https://arxiv.org/pdf/2104.13921.pdf" target="_blank" rel="noopener">Open-vocabulary Object Detection via Vision and Language Knowledge Distillation</a><br><strong>代码链接</strong>:<a href="https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/vild" target="_blank" rel="noopener">https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/vild</a><br><strong>整体信息</strong>:这是google research 发表在ICLR2022上有关CLIP在下游任务-目标检测任务上的应用。使用CLIP模型实现zero-shot场景下的目标检测任务。比较有想象意义的是,通过一句话就可以检测出图像中需要的指定目标。在之前<a href="https://nicehuster.github.io/2022/06/03/CLIP/#more" target="_blank" rel="noopener">CLIP图文多模态对比预训练方法详解</a>中也有提及过这篇工作。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="multi-model" scheme="https://blog.nicehuster.cn/tags/multi-model/"/>
</entry>
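<p>ViLD 的推理以及跨数据集迁移也可以用几行代码示意(同样是非官方的简化写法,embedding 用随机张量代替真实的 CLIP 输出):只要把类别文本换成新数据集的类别并重新编码出 text embedding,检测器本身无需任何 fine-tune:</p><pre><code class="python">import torch
import torch.nn.functional as F

@torch.no_grad()
def classify_regions(region_emb, text_emb, tau=0.01):
    """region_emb: [R, D];text_emb: [C, D](base+novel 或新数据集的全部类别)。
    返回每个 region 的类别概率,相当于用 text embedding 充当分类器权重。"""
    sim = F.normalize(region_emb, dim=-1) @ F.normalize(text_emb, dim=-1).t()
    return (sim / tau).softmax(dim=-1)

# 迁移到新数据集:只需替换类别文本,重新用 CLIP text encoder 编码出 text_emb_new
# (这里用随机张量代替真实的 CLIP 输出,仅作形状示意)
region_emb   = torch.randn(5, 512)
text_emb_new = torch.randn(80, 512)   # 例如 COCO 的 80 个类别
probs = classify_regions(region_emb, text_emb_new)
print(probs.shape)                    # torch.Size([5, 80])
</code></pre>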
<entry>
<title>DeCLIP一种数据高效的CLIP训练方法</title>
<link href="https://blog.nicehuster.cn/2022/06/09/DeCLIP/"/>
<id>https://blog.nicehuster.cn/2022/06/09/DeCLIP/</id>
<published>2022-06-09T11:13:39.000Z</published>
<updated>2022-06-13T03:56:00.801Z</updated>
<content type="html"><![CDATA[<p><strong>论文信息</strong>:<a href="https://arxiv.org/pdf/2110.05208.pdf" target="_blank" rel="noopener">Supervision Exists Everywhere: A Data Efficient Contrastive Language-Image Pre-training Paradigm</a><br><strong>代码链接</strong>:<a href="https://github.com/Sense-GVT/DeCLIP" target="_blank" rel="noopener">https://github.com/Sense-GVT/DeCLIP</a><br><strong>整体信息</strong>:这是商汤科技发表在ICLR2022上关于多模态预训练的工作,在前面的文章中介绍过CLIP,是一种基于对比文本-图像对的预训练方法,该方法需要在大量的图像-文本对数据集进行训练,在CLIP工作上就使用了4亿的图像-文本对数据,数百张卡进行预训练。为了提高训练效率,这篇工作提出了DeCLIP(Data Efficiency CLIP)方法,在较少数据下依旧可以取得不错的效果。</p><a id="more"></a><p><img src="/img/declip-sota.png" alt="declip-sota"></p><h3 id="具体方法"><a href="#具体方法" class="headerlink" title="具体方法"></a>具体方法</h3><p><img src="/img/clip-vs-declip.png" alt="clip-vs-declip"></p><p>上图,直观地,展示的是CLIP和DeCLIP方法的差异。CLIP是直接学习原始图片与对应的文本信息,使用俩个encoder分别编码图像信息和文本信息。图像encoder一般是resnet或者ViT,文本encoder一般使用transformer。之后将俩个embedding映射到相同的空间中,使用对比学习的思想进行训练。从方法上看,其实只使用了图像-文本对匹配的一种监督信号进行训练。假设batch size是N,共计N个图像-文本对$\left{\left(x<em>{i}^{I}, x</em>{i}^{T}\right)\right}$,损失函数InfoNCE如下:</p><script type="math/tex; mode=display">L_{I}=-\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp \left(\operatorname{sim}\left(\boldsymbol{z}_{i}^{I}, \boldsymbol{z}_{i}^{T}\right) / \tau\right)}{\sum_{j=1}^{N} \exp \left(\operatorname{sim}\left(\boldsymbol{z}_{i}^{I}, \boldsymbol{z}_{j}^{T}\right) / \tau\right)}</script><p>不同于CLIP,在DeCLIP方法中,使用了更多的自监督信号:1. 单模态的自监督学习;2. 跨模态的多视角监督学习;3. 最近邻监督学习;具体地,</p><ol><li><p><strong>单模态自监督学习</strong>(self-supervision within each modality, SS),包括使用<strong>SimSiam</strong>作为图像的自监督信号,和使用掩码语言模型<strong>MLM</strong>作为文本的自监督信号;</p><p><img src="/img/declip-ss.png" alt="declip-ss"></p><p>(a)<strong>图像自监督</strong>:同一张图片进过数据增强获得俩个view:$(x^{I}, \tilde{x}^{I})$,将经过数据增强后的结果经过相同的encoder得到俩个embedding向量$(z^{I}, \tilde{z}^{I})$,之后将其中一个embedding向量$x^{I}$再经过一个perd层得到向量$p^{I}$,训练时让$p^{I}$和$\tilde{x}^{I}$ 尽量接近;</p><p>(b)<strong>文本自监督</strong>:文本自监督使用的是MLM方法,即随机mask掉文本中15%的token,然后利用前后token预测被mask掉的token;</p></li><li><p><strong>跨模态多视角监督学习</strong>(Multi-View Supervision, MVS):CLIP只使用的原始图像-文本对$\left(z^{I}, z^{T}\right)$,计算infoNCE损失,而DeCLIP中使用的是增强后的文本和图像计算infoNCE损失:$\left(z^{I}, z^{T}\right), \quad\left(\tilde{z}^{I}, z^{T}\right),\left(z^{I}, \tilde{z}^{T}\right), \quad\left(\tilde{z}^{I}, \tilde{z}^{T}\right)$ ,相比CLIP多了3个监督信息;</p></li><li><p><strong>最近邻监督学习</strong>(Nearest-Neighbor Supervision, NNS):考虑到相同的图像可能会有类似的语言描述,因此选择语言描述相似的图文进行对比学习,通过维护一个先入先出的队列来模拟整个数据的分布,从队列中选择最相似的句子作为正样本$z^{T^{\prime}}$,之后使用InfoNCE计算最近邻损失:$\left(z^{I}, z^{T^{\prime}}\right),\left(\tilde{z}^{I}, z^{T^{\prime}}\right)$;</p><p><img src="/img/declip-nss.png" alt="declip-nss"></p></li></ol><p>在损失函数层面上,对以上三种不同监督的损失进行加权求和,得到最终的loss,具体地,如下所示:</p><script type="math/tex; mode=display">L_{D e C L I P}=(1-\alpha-\beta-\gamma) L_{C L I P}+\alpha\left(L_{I S S}+L_{T S S}\right)+\beta L_{M V S}+\gamma L_{N N S}</script><h3 id="数据集"><a href="#数据集" class="headerlink" title="数据集"></a>数据集</h3><p><img src="/img/declip-dataset.png" alt="declip-dataset"></p><p>在DeCLIP中,数据集包含俩部分:开源数据集29M和网络下载的数据集59M,总共88M训练数据,相比于CLIP使用的400M数据少很多。</p><h3 id="实验"><a href="#实验" class="headerlink" title="实验"></a>实验</h3><ol><li><p>Zero-shot准确率;</p><p><img src="/img/declip-zero-shot.png" alt="declip-zero-shot"></p><p>相比于CLIP,使用更少的训练数据,得到了更高的准确率;</p></li><li><p>下游任务表现;</p><p><img src="/img/declip-finetune.png" alt="declip-finetune"></p><p>在resnet和ViT俩种不同的encoder上,都证明了DeCLIP学习到的特征表示相比CLIP要强;</p></li><li><p>Ablation study</p><p><img src="/img/declip-ablation.png" 
alt="declip-ablation"></p><p>如上图证明了使用多种监督信息可有效的提升zero-shot准确率,而且相比于CLIP,DeCLIP的训练效率更高;</p></li></ol><h3 id="最后"><a href="#最后" class="headerlink" title="最后"></a>最后</h3><p>作者还在DeCLIP的基础上提出了<a href="https://arxiv.org/abs/2203.05796" target="_blank" rel="noopener">CLIP-benchmark</a>,其中包含了高质量的YFCC15M-V2数据集,而且复现了CLIP系列的相关方法(CLIP,DeCLIP,FILIP,DeCLIP,DeFILIP)。目前代码均已开源在<a href="https://github.com/Sense-GVT/DeCLIP" target="_blank" rel="noopener">这里</a>。</p>]]></content>
<summary type="html">
<p><strong>论文信息</strong>:<a href="https://arxiv.org/pdf/2110.05208.pdf" target="_blank" rel="noopener">Supervision Exists Everywhere: A Data Efficient Contrastive Language-Image Pre-training Paradigm</a><br><strong>代码链接</strong>:<a href="https://github.com/Sense-GVT/DeCLIP" target="_blank" rel="noopener">https://github.com/Sense-GVT/DeCLIP</a><br><strong>整体信息</strong>:这是商汤科技发表在ICLR2022上关于多模态预训练的工作,在前面的文章中介绍过CLIP,是一种基于对比文本-图像对的预训练方法,该方法需要在大量的图像-文本对数据集进行训练,在CLIP工作上就使用了4亿的图像-文本对数据,数百张卡进行预训练。为了提高训练效率,这篇工作提出了DeCLIP(Data Efficiency CLIP)方法,在较少数据下依旧可以取得不错的效果。</p>
</summary>
<category term="paper reading" scheme="https://blog.nicehuster.cn/categories/paper-reading/"/>
<category term="multi-model" scheme="https://blog.nicehuster.cn/tags/multi-model/"/>
</entry>
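<p>DeCLIP 最近邻监督(NNS)的核心是一个先入先出的文本特征队列。下面是一个假设性的简化示意(队列大小、初始化和入队策略等都不是论文的原始实现):为 batch 中每个文本 embedding 在队列里检索最相似的条目作为额外正样本,再复用对比损失:</p><pre><code class="python">import torch
import torch.nn.functional as F

class TextQueue:
    """维护一个先入先出的文本 embedding 队列,用来近似整个数据集的分布(简化示意)。"""
    def __init__(self, dim, size=4096):
        self.feats = F.normalize(torch.randn(size, dim), dim=-1)  # 实际实现应从空队列逐步填充
        self.ptr = 0

    @torch.no_grad()
    def nearest(self, z_txt):
        """为每个文本 embedding 检索队列中余弦相似度最高的条目,作为最近邻正样本 z^{T'}。"""
        sim = F.normalize(z_txt, dim=-1) @ self.feats.t()
        return self.feats[sim.argmax(dim=-1)]

    @torch.no_grad()
    def enqueue(self, z_txt):
        """把当前 batch 的文本 embedding 写入队列(先入先出)。"""
        n = z_txt.size(0)
        idx = torch.arange(self.ptr, self.ptr + n) % self.feats.size(0)
        self.feats[idx] = F.normalize(z_txt, dim=-1).detach()
        self.ptr = int((self.ptr + n) % self.feats.size(0))

queue = TextQueue(dim=512)
z_img, z_txt = torch.randn(32, 512), torch.randn(32, 512)
z_txt_nn = queue.nearest(z_txt)
# 可与前面 InfoNCE 示意中的 info_nce 搭配:loss_nns = info_nce(z_img, z_txt_nn)
queue.enqueue(z_txt)
</code></pre>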
</feed>