MegEngine TensorCore 卷积算子实现原理

作者:章晓 | 旷视 MegEngine 架构师

前言

2020 年 5 月 Nvidia 发布了新一代的 GPU 架构安培(Ampere)。其中和深度学习关系最密切的莫过于性能强劲的第三代的 TensorCore ,新一代的 TensorCore 支持了更为丰富的 DL(Deep Learning)数据类型,包括了新的 TesorFloat-32(TF32),Bfloat16(BF16)计算单元以及 INT8, INT4 和 INT1 的计算单元,这些计算单元为 DL 推理提供了全面的支持。为了发挥这些计算单元的能力,以往会由资深的 HPC 工程师手写 GPU 汇编实现的卷积、矩阵乘算子来挖掘硬件的能力。然而凭借人力手工优化算子的方式已经没有办法应对如此多的数据类型,因此对于 DL 应用的优化渐渐地越来越依赖一些自动化的工具,例如面向深度学习领域的编译器。在这样的趋势下, Nvidia 开发了线性代数模板库 CUTLASS ,抽象了一系列高性能的基本组件,可以用于生成各种数据类型,各种计算单元的卷积、矩阵乘算子。 MegEngine 在 CUTLASS 的基础上进行了二次开发,可以高效地开发新的高性能的算子,快速地迁移到新的 GPU 架构。在上一篇 文章 中,我们已经简单介绍了 MegEngine 的底层卷积算子实现的使用方法,而本文将会深入介绍 MegEngine CUDA 平台的底层卷积算子的实现原理,并将会对 Nvidia CUTLASS 的 Implicit GEMM 卷积 文档 进行解读和补充。

因此,读者在阅读本文之前必须要了解的 CUDA 知识有:

  • 访问全局存储(Global Memory)时,同一 Warp 中的相邻线程访问连续的地址,访存请求会被合并,合并的访存能够最大化 Global Memory 的吞吐。
  • 访问 Global Memory 时,尽可能使用最宽的数据类型(float4)进行访问,这样可以最大化访存指令的利用率。
  • CUDA 的共享存储(Shared Memory)按照每 4Bytes 划分为一个 bank,共分为 32 个 bank。当同一 Warp 中的线程访问同一 bank 的不同地址时会发生冲突(bank conflict)。无 bank conflict 的访存模式才能最大化 Shared Memory 的吞吐。
  • GPU 有显存(Global Memory)、L2、L1(Shared Memory)、寄存器 4 个层次的存储,直接访问显存的延迟很高,在优化 GEMM、Convolution 这样的计算密集型算子时,需要
    • 通过 L1 和寄存器的缓存来减少 Global Memory 的访存请求。
    • 通过大量的计算来隐藏不可避免的 Global Memory 访存延迟。

首先,我们需要了解 CUTLASS 引入的一些抽象概念

  • TileIterator : 用于访问存储中的一个Tile的数据。TileIterator 实现了advance()方法,支持在 Matrix , Tensor 等数据类型上进行遍历。
  • Fragment : 数组类型,用于存放 TileIterator 读取进来的数据。 Fragment 的数据通常存放在寄存器中。

然后我们简单回顾一下 CUTLASS 设计的高性能的 GEMM 算子的 Pipeline,按照 Pipeline 实现的算子能够在 CUDA 平台上达到 cublas 的 90% 以上的性能。下图演示了 CUTLASS 设计的 Pipeline 化的 GEMM 算子:

1.png

  1. 图中第一行演示了由 PredicatedTileIteratorSmemTileIterator 配合完成从 Global Memory 到 Shared Memory 的数据搬运。
  2. 第二行演示了 WarpTileIterator 负责从 Shared Memory 搬运数据到 Fragment 寄存器中。
  3. 第三行展示了WarpMmaOperatorFragment 寄存器中的矩阵数据执行矩阵乘加 (Matrix-Multiply-Add) 操作。

Implicit GEMM 算法

卷积映射为矩阵乘法

我们首先来看一下前向卷积算子的定义,假设输入的 feature map 是 x,卷积层的 weight 是 w,输出是 y,其中 x,y,w 都是 4 维的 Tensor,x 的四个维度分别是 NxICxIHxIW,w 的四个维度分别是 OCxICxFHxFW,y 的四个维度分别是 NxOCxOHxOW。那么输出 y 和输入 x, w 的数学关系式可以写成

y(n,oc,oh,ow)=icfhfwx(n,ic,ih,iw)w(oc,ic,fh,fw)\text{y}( \text{n}, \text{oc}, \text{oh}, \text{ow} ) = \sum_{\text{ic}} \sum_{\text{fh}} \sum_{\text{fw}} \text{x} (\text{n}, \text{ic}, \text{ih}, \text{iw}) \cdot \text{w} ( \text{oc}, \text{ic}, \text{fh}, \text{fw} )

© 版权声明
THE END
喜欢就支持一下吧
点赞0 分享