设计 kernel 使用到的核心公式

1. 某个元素的全局索引(计算扁平化索引)

维度为 [d0, d1, d2] 的3维数组,各维度索引 (i, j, k),且元素在内存中是连续的,那么位于局部索引的元素的全局ID计算公式为:

[d0, d1, d2]
(i,  j,  k )
  • 行主序 (d2维变化(k变化)最快) 全局 ID = 1*k + d2*j + (d1*d2) *i
  • 列主序 (d0维变化(i变化)最快) 全局 ID = 1*i + d0*j + (d0*d1) *k

行主序、列主序字面上易误解,因为多维数组并不只有行和列。更准确的名字应该是 内主序外主序

理解:从变化最快的索引开始,依次向外维累加,以列主序为例:

  • 第一项,索引是ii对应的维度已经是变化最快的了,这里的乘数是 1;
  • 第二项,索引是j(i外围索引),另一项是j对应维度内侧所有维度累积,另一项是 j 对应维度元素的 stride,即d0
  • 第三项,索引是k(j的外围索引),另一项是k对应维度内侧所有维度累积,另一项是 k 对应维度元素的 stride, 即d0*d1

每一项是当前维度的索引 * 这个维度中元素的 stride。这是一般情况,当这个 tensor 中的元素不是连续的时,每一项中的 stride 是字节 stride,而非元素 stride。

也就是说,当不连续存储的 tensor 时,只要每个维度的字节 stride 正确给出,都可以使用这个公式计算全局索引(或者是字节偏移)。

把这两个公式作为思想钢印印在脑子里。

2. 多少个 a 可以覆盖 b

block_num_y = (2048 + 4*128 - 1) / (4*128);

的意思是保证 block y 方向的个数可以覆盖 2048,上述公式实现的是 ceiling 除法,表示:多少个 4*128 可以覆盖 2048,向上取整。

C++ 中 2048/512 = (2048+512-1)/512 = 4. 2048 恰好可以被 512 整除。但如果是 2049呢?

C++ 中 2049/512 = 4 != (2049+512-1)/512 = 5 2049 中余下一个元素被漏掉了。

所以涉及到上述语义(多少个 a 可以覆盖 b )时,使用 Ceilling 除法 (b+a-1)/a

3. id/num1id%num1

上述只是简写,完整的表达是(假设 shape=[num0,num1,num2], col-major, id 是全局索引):

x 维度内纬度中数值索引    = (id / [x维度中数值的stride]) % [这个维度x]

num0 维度内纬度中数值索引 = (id/1)          %num0 
num1 维度内纬度中数值索引 = (id/1*num0)     %num1
num2 维度内纬度中数值索引 = (id/1*num0*num1)%num2

每个维度之间的索引无耦合

比如 num1=2(索引0,1),num2=3(索引0,1,2), 共 2*3=6 个元素,id=0,1,2,3,4,5

  • (0,1,2,3,4,5)%num1=(0,1, 0,1, 0,1) 当前乘数 num1 维度的索引
  • (0,1,2,3,4,5)/num1=(0,0, 1,1, 2,2) 另一个维度 num2 的索引

完整形式

  • ((0,1,2,3,4,5)/1) %num1 = (0,1, 0,1, 0,1)
  • ((0,1,2,3,4,5)/num1)%num2 = (0,0, 1,1, 2,2)

实际场景中,要根据根据 row-majro、col-major 确定除数是谁。具体地,如 shape 是 [2,4] 的8个数,row-major 存储,则:

id:             0,1,2,3,  4,5,6,7,
col_id = id%4 = 0,1,2,3,  0,1,2,3,
row_id = id/4 = 0,0,0,0,  1,1,1,1, 

*** 完整形式:
(id/1)%4      = 0,1,2,3,  0,1,2,3,  = id%4
(id/4)%2      = 0,0,0,0,  1,1,1,1,  = id/4

col-major 存储(物理存储),则:

id:             0,2,4,6,  1,3,5,7,
col_id = id%2 = 0,0,0,0,  1,1,1,1,
row_id = id/2 = 0,1,2,3,  0,1,2,3,

*** 完整形式:
(id/1)%2      = 0,0,0,0,  1,1,1,1,  = id%2
(id/2)%4      = 0,1,2,3,  0,1,2,3,  = id/2

col-major 存储(只是横着写,意义与上述一致):

id:             0,1  2,3  4,5  6,7
row_id = id/2 = 0,0  1,1  2,2  3,3
col_id = id%2 = 0,1  0,1  0,1  0,1

实际中,除数应该是两个维度中变化较快的那个维。将上述作为思想钢印印在脑中。

每个元素都有 row_id 和 col_id

上述公式仅适用于连续存储的 tensor。

4. 非连续存储的 tensor 计算各个维度的索引

i是扁平化后的全局索引,shape=[ne00,ne01,ne02,ne03],stride=[s00(1),s01,s02,s03]

    const int64_t s03 = ne00*ne01*ne02;
    const int64_t s02 = ne00*ne01;
    const int64_t s01 = ne00;

    const int64_t i03 = i / s03;
    const int64_t i02 = (i - i03*s03 ) / s02;
    const int64_t i01 = (i - i03*s03  -  i02*s02) / s01;
    const int64_t i00 = (i - i03*s03 - i02*s02 - i01*s01) / s00;

5. 只有结合 block 配置和 grid 配置的计算

与任务和具体 block/grid 配置有关, 如何相关?画一个小 case 小配置 就知道了。

实例(扁平化):

grid (13,4,1), block (128,1,1) grid 中一列 block 的 thread id 这样计算:

(0,1,2,3,...,127) + 128 * (0,1,2,3)^T 正好是(有 T 表示竖着排列,没有 T 表示横着排), 将这个计算补充完整得到,2D config中thread 扁平化后的索引:

0  , 1  , 2  ,..., 126, 127,
128, 129, 130,..., 254, 255,
256, 257, 258,..., 382, 383,
384, 385, 386,..., 510, 511

进而推导出索引公式是:int i0 = threadIdx.x + blockDim.x * blockIdx.y; 可以看出依然是核心公式1,因为们的语义是一样的

内存访问模式 threadIdx.x (变化最快的)它访问连续的内存地址。故这是好的内存访问模式。

当前 case 每个 thread 要访问连续的4个值,所以:

0   , 4   , 8   , ..., 504 , 508
512 , 516 , 520 , ..., 1016, 1020
1024, 1028, 1032, ..., 1524, 1528
1536, 1540, 1544, ..., 2020, 2024

推导出是 i0=i0*4, 得到每个 thread 访问连续 4 个数据中第一个数据索引:

int i0 = (threadIdx.x + blockDim.x * blockIdx.y)*4;

覆盖 2048 个元素。形式也是 [idx * num + idy],这里的 num 是 128

对于数据索引 i0,前 512 个数由 128 个 thread 负责,将这 128 个 thread 分为4(组)个 warp(32thread),每组负责 128个元素(4组覆盖这512个数)。

对于每一组(32 thread),每组 8 个 thread 负责32个数的归约求和,所以有这样的4组。

各个索引推导流程

两种索引

一个 kernel 中一定有两种索引:数据访问存储索引(包括Shared memory) & 线程索引。而前者一定是后者的线性组合(这就是两种索引的关系)。

故对于计算索引的方法论是:画出读写示意图,倒推各个索引,kernel 一定是需要输入输出访问存储 id,而这个最终 id 一定是核心公式中的一个(根据row-major、col-mahor选择)。有了公式,找其中的各个项。… 如此一步一步向前找,最终一定是回到 CUDA built-in 的线程 id。

所以上述过程是由一个映射的:从线程id 到数据访问存储 id。

id 映射

给出一个小规模的实例,计算出所有元素的 id,包括公式中各项的 id。

判断bank conflict的公式

声明: __shared__ float tile[32][32] , 内存(row-major)访问模式是 tile[threadIdx.x][threadIdx.y] 则:

  • tile 一维地址:tile_id = threadIdx.x * 32 + threadIdx.y
  • warp 中线程访问字节地址是:byte_addr = 0 + tile_id * 4B
  • 对应的 bank_id 是:byte_addr//4%32

化简后:

tile_id 与 bank_id 的关系:bank_id = 0 + tile_id % 32