带你轻松理解 Transformer(上)

这是我参与更文挑战的第23天,活动详情查看: 更文挑战

引言

Transformer 模型内部细节很多,本文只主要讲解 Attention 部分和 Self-Attention 部分,如果感兴趣可以查看论文。

什么是 Transformer

  • Transformer 是一个 Seq2Seq 模型,很适合机器翻译任务。不了解 Seq2Seq 模型的,可以看我之前的文章 《Seq2Seq 训练和预测详解以及优化技巧》

  • 它不是循环神经网络结构,而是单纯靠 Attention 、Self-Attention 和全连接层拼接而成的网络结构。

  • Transformer 的测评性能完全碾压最好的 RNN + Attention 结构,目前业内已经没有人用 RNN ,而是用 BERT + Transformer 模型组合。

回顾 RNN + Attention 结构

transformer-rnn-attention.jpg

如图所示是一个 RNN + Attention 组合而成的模型,在 Decoder 的过程中计算 cj 的过程如下:

a)将 Decoder 的第 j 时刻的输出向量 sj 与 WQ 相乘得到一个 q:j

b)将每个 Encoder 的隐层输出 hi 与 WK 相乘得到 k:i ,因为有 m 个输入,所以有 m 个 k:i 向量,用 K 表示。

c)用 KT 与 q:j 相乘,可以得到一个 m 维的向量,然后经过 Softmax 就可以得到 m 个权重 aij

【解释】q:j 被成为 Query ,k:i 被称为 Key,Query 的作用是用来匹配 Key ,Key 的作用是被 Query 匹配,经过计算得到的权重 aij 表示的就是 Query 和每个 Key 的匹配程度,匹配程度越高 aij 越大。我认为可以这样理解, Query 捕获的是 Decoder 的 sj 特征,Key 捕获的是 Encoder 输出的 m 个 hi 的特征,aij 表示的就是 sj 与每个 hi 的相关性。

d)将每个 Encoder 的隐层输出 hi 与 WV 相乘得到 v:i ,因为有 m 个输入,所以有 m 个 v:i 向量,用 V 表示。

e)经过以上的步骤,Decoder 的第 j 时刻的 cj 可以计算出来,也就是 m 个 a 和对应的 m 个 v 相乘,求加权平均得到的。

【注意】WV、WK、WQ 三个参数矩阵是需要从训练数据中学习。

Transformer 中的 Attention

在 Transformer 中移除了 RNN 结构,只保留了 Attention 结构,可以从下图中看出,使用 WK 和 WV 与 Encoder 的 m 个输入 x 进行计算分别得到 m 个 k:i 和 m 个 v:i 。使用 WQ 和 Decoder 的 t 个输入 x 进行计算得到 t 个 q:t

transformer-attention-1.jpg

如下图,这里是计算 Decoder 第 1 个时刻的权重, 将 KT 与 q:1 相乘,经过 Softmax 转化得到 m 个权重 a ,记做 a :1

transformer-attention-2.jpg

如下图,这里是计算 Decoder 第 1 个时刻的上下文特征 c:1 ,将 m 个权重 a 与 m 个 v 分别相乘求和,得到加权平均结果即为 c:1

transformer-attention-3.jpg

类似地,Decoder 的每个时刻的上下文特征都可以和上面一样计算出来。说白了 c:j 依赖于当前的 Decoder 输入 xj 以及所有的 Encoder 输入 [x1,…,xm] 。

transformer-attention-4.jpg

总结如下图,Encoder 的输入是序列 X ,Decoder 的输入是序列 X ,上下文向量 C 是关于 X 和 X 的函数结果,其中用到的三个参数矩阵 WQ 、WK 、WV 都是需要通过训练数据进行学习的。

transformer-attention-5.jpg

下图所示是机器翻译的解码过程,Transformer 的这个过程和 RNN 的过程类似,RNN 是将状态向量 h:j 输入到 Softmax 分类器,只不过 Attention 是将上下文特征 c:j 输入到 Softmax 分类器,然后随机抽样可以预测到下一个单词的输入。

transformer-attention-6.jpg

Transformer 中的 Self-Attention

Self-Attention 的输入只需要一个 X 输入序列,这里分别用 WQ 、WK 、WV 与每个输入 xi 进行计算得到 m 个 q:i、 k:i、 v:i 三个向量。而第 j 时刻的权重和上面的计算方式一样,Softmax(KT * q:j) 可以得到第 j 时刻的 xj 关于所有输入 X 的 m 个权重参数,最后将 m 个权重参数与 m 个 v:i 分别相乘求和,即可得到上下文向量 c:j

transformer-self-attention-1.jpg

类似的,所有时刻的 c:j 都可以用同样的方法求出来。

transformer-self-attention-2.jpg

总结,输入是序列 X ,上下文向量 C 是关于 X 和 X 的函数结果,因为每次在计算 xj 的上下文向量 cj的时候,都是需要将 xj 与所有 X 一起考虑进去并进行计算。其中用到的三个参数矩阵 WQ 、WK 、WV 都是需要通过训练数据进行学习的。

transformer-self-attention-3.jpg

参考

[1] Vaswani A , Shazeer N , Parmar N , et al. Attention Is All You Need[J]. arXiv, 2017.

© 版权声明
THE END
喜欢就支持一下吧
点赞0 分享