Skip to content

Commit

Permalink
Update Transformer-NAR.md
Browse files Browse the repository at this point in the history
  • Loading branch information
OctoberFox11 authored Jun 17, 2024
1 parent d029d3f commit f4716b9
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions Transformer-NAR.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,23 @@ NAR通常是指构建能够捕捉算法计算的神经网络的艺术。通过

# 本节描述了我们的混合TransNAR架构

本节描述了我们的混合TransNAR架构(参见图fig:architecture)。TransNAR接受双重输入,包括文本算法问题规范(包含T个标记)及其相应的CLRS-30特定图形表示(包含N个节点),并输出问题的文字回应。我们可以假设,一旦编码,文本输入存储在**T**一个T×k的实数矩阵)中,图形输入存储在**G**一个N×l的实数矩阵)中。注意,为了简化下面的方程,我们假设所有与问题图形版本相关的信息都存储在节点中——这在CLRS-30中通常不是真的(也可能有边和图级别的输入),但这不会改变下面介绍的基本数据流。
本节描述了我们的混合TransNAR架构(参见图fig:architecture)。TransNAR接受双重输入,包括文本算法问题规范(包含T个标记)及其相应的CLRS-30特定图形表示(包含N个节点),并输出问题的文字回应。我们可以假设,一旦编码,文本输入存储在**T**一个 \( T \times k \) 的实数矩阵)中,图形输入存储在**G**一个 \( N \times l \) 的实数矩阵)中。注意,为了简化下面的方程,我们假设所有与问题图形版本相关的信息都存储在节点中——这在CLRS-30中通常不是真的(也可能有边和图级别的输入),但这不会改变下面介绍的基本数据流。

TransNAR的前向传递如下展开。首先,我们通过设置**T(0) = T****G(0) = G**正确初始化输入。接下来,为了计算步骤(t+1)的表示,文本(标记)表示被送入Transformer的当前层:
```latex
TransNAR的前向传递如下展开。首先,我们通过设置 \( \mathbf{T}^{(0)} = \mathbf{T} \)\( \mathbf{G}^{(0)} = \mathbf{G} \) 正确初始化输入。接下来,为了计算步骤 \( (t+1) \) 的表示,文本(标记)表示被送入Transformer的当前层:
$$
\mathbf{\Theta}^{(t+1)} = \text{FFN}\left(\text{softmax}\left(\frac{(\mathbf{T}^{(t)}\mathbf{Q}_t)^\top\mathbf{T}^{(t)}\mathbf{K}_t}{\sqrt{d_k}}\right)\mathbf{T}^{(t)}\mathbf{V}_t\right)
$$
```
其中**Q_t**, **K_t**是键和查询矩阵,**V_t**是值矩阵,FFN是前馈网络。类似地,图形表示被送入NAR层,实现例如标准max-MPNN:
其中 \( \mathbf{Q}_t, \mathbf{K}_t \in \mathbb{R}^{k \times d_k}, \mathbf{V}_t \in \mathbb{R}^{k \times k} \) 分别是键、查询和值转换,而 FFN 是一个前馈网络。类似地,图形表示被送入 NAR 层,实现例如标准 max-MPNN:
$$
\mathbf{g}^{(t+1)}_u = \phi\left(\mathbf{g}^{(t)}_u, \max_{1\leq v\leq N}\psi\left(\mathbf{g}^{(t)}_u, \mathbf{g}^{(t)}_v\right)\right)
\mathbf{g}^{(t+1)}_u = \phi\left(\mathbf{g}^{(t)}_u, \max_{1 \leq v \leq N}\psi\left(\mathbf{g}^{(t)}_u, \mathbf{g}^{(t)}_v\right)\right)
$$
其中ψ和φ是可学习的消息和更新函数,max是逐元素最大聚合。注意方程只提供节点之间的成对交互——实际上,我们的NAR是一个Triplet-GMPNN,也包含三元组交互和一个门控机制。进一步注意,NAR的可学习部分没有时间步索引——每一步,应用的是共享函数。这与算法计算在图形上的迭代、重复性质很好地对齐。
其中 \( \psi, \phi : \mathbb{R}^k \times \mathbb{R}^k \rightarrow \mathbb{R}^k \) 分别是可学习的消息和更新函数,而 max 是逐元素最大聚合。注意方程只提供节点之间的成对交互——实际上,我们的 NAR 是一个 Triplet-GMPNN,也包含三元组交互和一个门控机制。进一步注意,NAR 的可学习部分没有时间步索引——每一步,应用的是共享函数。这与算法计算在图形上的迭代、重复性质很好地对齐。

一旦两个流都准备好了它们的表示**Θ(t+1)****G(t+1)**,图中的节点嵌入就调节Transformer的标记嵌入,产生Transformer流中TransNAR块的最终结果,灵感来自Flamingo
一旦两个流都准备好了它们的表示 \( \mathbf{\Theta}^{(t+1)} \)\( \mathbf{G}^{(t+1)} \),图中的节点嵌入就调节 Transformer 的标记嵌入,产生 Transformer 流中 TransNAR 块的最终结果,灵感来自 Flamingo
$$
\mathbf{T}^{(t+1)} = \text{FFN}\left(\text{softmax}\left(\frac{(\mathbf{\Theta}^{(t)}\mathbf{Q}^\times_t)^\top\mathbf{G}^{(t)}\mathbf{K}^\times_t}{\sqrt{d_k}}\right)\mathbf{G}^{(t)}\mathbf{V}^\times_t\right)
$$
其中**Q_t^\times**, **K_t^\times**是交叉注意力的键和查询矩阵,**V_t^\times**是值矩阵。在结束这一层之前,不执行**G(t+1)**的额外转换。

其中 \( \mathbf{Q}_t^\times, \mathbf{K}_t^\times \in \mathbb{R}^{k \times d_k}, \mathbf{V}_t^\times \in \mathbb{R}^{k \times k} \) 分别是交叉注意力的键、查询和值转换。在结束这一层之前,不执行 \( \mathbf{G}^{(t+1)} \) 的额外转换。
这个过程一直重复,直到最终的第$N_l$层,当最终的文本输出从${\bf T}^{(N_l)}$中读取出来。最终输出通过最终层产生的预测头转换为标记logits,我们通过标准的下一个标记预测目标来监督。

在TransNAR微调开始之前,我们预先训练NAR以稳健地执行CLRS-30涵盖的三十个算法,类似于generalist。众所周知,这样的程序能够在图形空间中实现高达4倍的输入尺寸的分布外泛化。NAR的参数在微调期间通常保持**冻结**,因为额外的梯度会消除模型原始的鲁棒性属性。这也是为什么图形嵌入不执行交叉注意力的原因。LLM本身可能在大规模数据集上预先训练过,以建立其一般语言先验,尽管即使LM最初是随机初始化的,我们也恢复了相同的实验结果。
Expand Down

0 comments on commit f4716b9

Please sign in to comment.