这是我参与更文挑战的第21天,活动详情查看: 更文挑战
引言
之前我们上文介绍了用 Attention 来提升 Seq2Seq 的性能,将 Attention 共同作用于 Seq2Seq 的 Encoder 和 Decoder 两个部分。本文我们介绍 Self-Attention ,可以将 Attention 单独用到其中的一部分里面。
愿论文中 Self-Attention 作用于 LSTM ,这里我简化过程,用 SimpleRNN 代替 LSTM 介绍该思想。
SimpleRNN + Self-Attention 核心原理
【SimpleRNN 求 hi 的方法】
我们之前在 SimpleRNN 中求 hi 的时候,是按照下面的这个公式的思路进行的:
hi = tanh(A * concat(xi, hi-1)+b)
说明当前时刻的隐层状态依赖于当前的输入 xi 和上一时刻的隐层状态输入 hi-1 。
【SimpleRNN + Self-Attention 求 hi 的方法】
当引入 Self-Attention 之后,SimpleRNN 求 hi 的方式发生了变化,是按照下面的这个公式的思路进行的:
hi = tanh(A * concat(xi, ci-1)+b)
图中例子说明 t3 时刻的隐层状态 h3 依赖于当前的输入 x3 和上一时刻的上下文向量 c2 。
其中 ci 就是将第 i 时刻的隐层输出 hi 与已有的 h1、… 、 hi 进行权重计算,得到权重列表 a1、… 、 ai ,最后将这些隐层输出与各自对应的权重参数进行加权平均求和得到 ci 。至于具体的权重计算方法和 Attention 文章中提到的方法一样,这里不再赘述。
从图中的例子可以 c3 是 h1、 h2 、 h3 及各自对应权重 a1、 a2 、 a3 的加权平均和。
另外,可以考虑换更加复杂的计算思路,其他具体过程和上述一样:
hi = tanh(A * concat(xi, ci-1, hi-1,)+b)
总结
-
Self-Attention 和 Attention 一样,都能解决 RNN 类模型的遗忘问题,每次在计算当前隐层输出 hi 的时候,都会用 ci-1 来回顾一下之前的信息,这样就能记住之前的信息。但是 Self-Attention 中的 ci 的计算在自身的 RNN 结构中即可计算,而不像 Seq2Seq 中的 Attention 那样横跨 Decoder 和 Encoder 两个 RNN 结构,即 Decoder 的 ci 依赖于 Encoder 的所有隐层输出。
-
Self-Attention 可以作用于任何 RNN 类的模型来提升性能了,如 LSTM 等。
-
Self-Attention 还能帮助 RNN 关注相关的信息,如下图所示,红色单词是当前的输入,蓝色单词表示与当前输入单词较相关的单词。
参考
Cheng J , Dong L , Lapata M . Long Short-Term Memory-Networks for Machine Reading[C]// Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing. 2016.