拉格朗月

基于相对位置编码的自注意力机制

2021-07-17

在Transformer模型中没有显示地在模型结构上建模输入序列地绝对或者相对位置信息,而是通过位置编码地方式,将每个位置用一个向量来表示,然后与元素对应地词向量相加来使得模型可以感知元素地位置信息。在论文《Self-Attention with Relative Position Representations》中,作者提出了相对位置编码,对attention机制进行了扩展,使得模型可以感知输入序列不同元素之间的相对位置或者说是相对距离。

Paper Info

论文:《Self-Attention with Relative Position Representations》

arxiv: https://arxiv.org/pdf/1803.02155.pdf

动机

对于RNN一类的递归类模型,模型可以从其序列结构中学习到有关绝对或者相对位置信息,因此RNN适合用于序列建模任务。而对于Transformer模型,其在模型结构上没有显式的对位置进行建模,而是对每个位置用一个位置向量来表示位置信息,通过将位置向量表示与token向量表示相加,从而让模型感知到位置信息。

这篇文章则是在自注意力机制上,加上相对位置的信息,让Transformer这类基于自注意力机制的模型可以在自注意力计算过程中感知token之间的相对位置。

Transformer中的自注意力机制

对于一个有n个token的输入序列$x=(x_1, x_2, …, x_n)$,通过自注意力机制后,可以得到一个输出序列$z=(z_1, z_2, …, z_n)$,这个输出序列的每个token的新的表示$z_i$是通过将输入序列的每个token线性转换后再加权求和得到的,这里的权重就是注意力权重,具体可以通过如下的公式表示。

$$
z_i = \sum_{j=1}^{n} \alpha_{ij}(x_i W^V)
$$

$\alpha_{ij}$表示第i个token和第j个token的相关权重,通过softmax计算得到:

$$
\alpha_{ij} = \frac{exp(e_{ij})}{\sum_{k=1}^{n} exp(e_{ik})}
$$

$e_{ij}$通过scaled dot product进行计算:

$$
e_{ij} = \frac{(x_iW^Q)(x_jW^K)T}{\sqrt{d_z}}
$$

上面是非常经典的自注意力计算公式。$W^Q,W^K,W^V \in R^{d_x \times d_z}$,在Transformer的多头自注意力机制中,这三个参数矩阵在每层和没个head中都不一样。

引入相对位置编码

相对位置编码的思想很简单,即在模型训练和推理过程中,在计算$z_i$时,让模型可以知道第i个词相对于前几个词和后几个词的相对距离。比如对于下面的图,$z_i$前两个词是$z_{i-1},z_{i-2}$,它们相对于$z_i$的距离分别是-1,-2。这种相对距离,作者用向量来表示,其实和绝对位置一样,只不过含义变了,并且参与attention的计算方式也变了。由于每个token相对于其他token都有一个相对位置的信息,并且前向和后向的相对是不一样的,所以这可以将输入序列建模成一个全连接有向图的模型,相对位置即图中的边。

相对位置

对于两个输入元素$x_i$和$x_j$,它们两个的相对位置(或者说是图中的边)可以用向量来表示$a_{ij}^V,a_{ij}^K \in R^{d_a}$。这里有两个向量,分别作用于注意力中的value部分和key部分。加入相对位置信息后,相应的公式可做如下修改:

$$
z_i = \sum_{j=1}^{n} \alpha_{ij}(x_i W^V + a_{ij}^V)
$$

$$
e_{ij} = \frac{(x_iW^Q)(x_jW^K + a_{ij}^K)T}{\sqrt{d_z}}
$$

TransformerXL中,也引入了相对位置编码。从公式看两者的实现有一定的差异,但本质都是在计算attention时能够学习和感知前后的相对位置。

相对位置编码裁剪

由于超过一定距离之后,相对位置的作用就不太有用了,所以在考虑相对位置时,仅考虑一定范围内的距离即可。假设最大的相对位置为k,则一共有$2k+1$个相对位置,分别从-k到+k,每个位置用一个向量表示,具体可表示为:

$$
a_{ij}^K, a_{ij}^V = w_{clip(j-i,k)}^K, w_{clip(j-i,k)}^V
$$

$$
clip(x, k) = max(-k, min(k, x))
$$

其中的$w^K=(w_{-k}^K,…,w_k^K), w^V=(w_{-k}^V, …, w_{k}^V), w_i^K,w_i^V \in R^{d_a}$,所以最终的相对位置矩阵维度是$w^K,w^V \in R^{2k+1 \times d_a}$。

为什么绝对位置编码不包含相对位置信息?

why?