Home

SIMD直观介绍

上期讲到了Clickhoue性能高的原因是应用了向量化技术,也就是SIMD(Single Instruction Multiple Data,单指令多数据流)的指令来一次处理多个数据。这么讲有点抽象,今天就用几个例子感受一下SIMD具体是怎么运作的。

本文例子来自dendibakh的博客系列,一个非常好的介绍SIMD的系列,十分推荐。

First Look

先来看个简单的例子

#include <vector>
void foo( std::vector<unsigned>& lhs, std::vector<unsigned>& rhs )
{
    for( unsigned i = 0; i < lhs.size(); i++ )
    {
            lhs[i] = ( rhs[i] + 1 ) >> 1;
    }
}

我们用clang加上 -O2 -march=core-avx2 -std=c++14 -fno-unroll-loops 的编译选项来编译上面这段代码,-fno-unroll-loops告诉编译器禁止做循环展开的优化,-march=core-avx2提示编译器使用 avx2 指令集来做向量话。编译生成的汇编代码可以直观的在Compiler Explorer看到结果。

我们可以看到,循环内的代码实际上生成来两个版本,一个是正常版

mov edx, dword ptr [r9 + 4*rsi]         # loading rhs[i]
add edx, 1                              # rhs[i] + 1
shr edx                                 # (rhs[i] + 1) >> 1
mov dword ptr [rax + 4*rsi], edx        # store result to lhs
mov esi, edi
add edi, 1                              # incrementing i by 1
cmp rcx, rsi
ja <next iteration>

一个是使用了AVX2指令的向量化版本

vmovdqu ymm1, ymmword ptr [r9 + 4*rdi]  # loading 256 bits from rhs
vpsubd ymm1, ymm1, ymm0                 # ymm0 has all bits set, like +1
vpsrld ymm1, ymm1, 1                    # vector shift right.
vmovdqu ymmword ptr [rax + 4*rdi], ymm1 # storing result 
add rdi, 8                              # incrementing i by 8
cmp rsi, rdi
jne <next iteration>

在进入循环的时候会先判断lhs未处理的元素的大小是否大于8,若大于8个,则使用向量化的版本一次循环处理8个数字,若到后面小于8个待处理的元素,则使用正常的版本一个个数字处理。

其实大部分的汇编代码模式都是类似的,遵循 加载数据->处理数据->写回数据 这样的模式。对比两个版本的代码也可以看出,向量版的也是这样的模式,只是指令变成了可以一次操作多个数的版本, mov变成了vmovdqu, shr变成了vpsrld之类的。寄存器使用的也是位数更多的寄存器(256位的YMM寄存器或128位的XMM寄存器)。

更复杂的例子

下面我们再来详细看一个复杂点的例子。(godbolt)

int foo( short a[16], unsigned short b[16], unsigned short bias[16] )
{
  int agg = 0;
  for( int i = 0; i < 16; i++ )
  { 
    if( a[i] > 0 ) 
      a[i] = (bias[i] + a[i]) * b[i] >> 16; 
    else 
      a[i] = - ((bias[i] - a[i]) * b[i] >> 16); 
    agg += a[i]; 
  }
  return agg;
}

我们一段一段来看:

加载数据

 vmovdqu ymm4,YMMWORD PTR [rdi] # a[]
 vmovdqu ymm2,YMMWORD PTR [rdx] # bias[]
 vmovdqu ymm5,YMMWORD PTR [rsi] # b[]
 vmovdqa ymm6,YMMWORD PTR [rip+0x5c6] # 400b00

上面的指令只是单纯的加载数据到YMM寄存器里面,执行之后结果如下

解压

由于16位数的相加可能会移除,所以我们要把16位的整数扩展成32位的整数。

vpmovzxwd ymm1,xmm2
vextracti128 xmm8,ymm4,0x1
vextracti128 xmm2,ymm2,0x1

vpmovzxwd ymm9,xmm5
vpmovzxwd ymm2,xmm2
vpmovsxwd ymm0,xmm4
vpmovsxwd ymm8,xmm8
vextracti128 xmm5,ymm5,0x1

上面一顿操作之后,寄存器里面就变成了

计算

当数据都准备好后,就可以用vpaddd, vpsubd, vpmuld等指令批量做运算了

vpaddd ymm3,ymm1,ymm0
vpmovzxwd ymm5,xmm5
vpsubd ymm0,ymm1,ymm0
vpsubd ymm1,ymm2,ymm8
vpmulld ymm0,ymm0,ymm9
vpmulld ymm1,ymm1,ymm5
vpaddd ymm7,ymm2,ymm8
vpmulld ymm3,ymm3,ymm9
vpmulld ymm7,ymm7,ymm5

结果如下

注意,我们的循环代码里面有个if分支,那向量优化是如何做遇上if又是怎么做的呢?从上图的结果来看,十分的暴力,直接就都算了两个分支的结果,(bias[i] + a[i]) * b[i]的结果用result of +表示,保存在ymm3和ymm7。(bias[i] - a[i]) * b[i]的结果用result of - *表示,保存在ymm0和ymm1。

后面还要再接一个位移操作

vpsrad ymm0,ymm0,0x10
vpsrad ymm1,ymm1,0x10
vpand ymm1,ymm6,ymm1
vpsrad ymm3,ymm3,0x10
vpand ymm0,ymm6,ymm0

打包

vpackusdw ymm0,ymm0,ymm1
vpsrad ymm7,ymm7,0x10
vpxor xmm1,xmm1,xmm1

把32位的结果打包变回16位整数。

处理If语句

vpcmpgtw ymm4,ymm4,ymm1
vpand ymm3,ymm6,ymm3
vpand ymm7,ymm6,ymm7
vpackusdw ymm3,ymm3,ymm7

指令vpcmpgtw ymm4,ymm4,ymm1一次性比较ymm4和ymm1中每个16位整数的大小,结果保存回ymm4里面,也就是代码里面的if( a[i] > 0 )判断。

写回

vpermq ymm0,ymm0,0xd8
vpermq ymm3,ymm3,0xd8
vpsubw ymm0,ymm1,ymm0
vpblendvb ymm0,ymm0,ymm3,ymm4
vpmovsxwd ymm1,xmm0
vmovdqu YMMWORD PTR [rdi],ymm0

vpermq用于整理数据,vpsubw ymm0,ymm1,ymm0把ymm0的数据取负数。

最关键的是vpblendvb ymm0,ymm0,ymm3,ymm4,这条指令会根据ymm3的标示,从ymm0或者ymm3取数据到ymm0,由此就实现了if的功能。

后面还有累加ymm0的结果到变量agg的代码就忽略了。

数据依赖

有时候数据之间会有依赖关系,不能简单的并行化,就比如

void foo( unsigned short * a, unsigned short * b )
{
  for( int i = 0; i < 128; i++ )
  {
    a[i] += b[i]; 
  }
}

我们调用的foo函数的时候有

  unsigned short x[] = {1, 1, 1, 1, ... , 1}; // 129 elements
  unsigned short* a = x + 1;
  unsgined short* b = x;
  foo (a, b);

正常的计算由于a和b是指向的是同一个数组,会得出结果 x = {1, 2, 3, 4, 5, ... },但若是向量化的话,就会得出x = (2, 2, 2, 2, ...)的结果。

但实际上我们的计算结果是正确的,那是因为编译器在会为在计算前为我们做个变量的别名检查(godbolt)

lea rax, [rsi + 256] # calculating the end of b (b + 128)
cmp rax, rdi         # comparing the beginning of a and the end of b
jbe .LBB0_4
lea rax, [rdi + 256] # calculating the end of a (a + 128)
cmp rax, rsi         # comparing the beginning of b and the end of a
jbe .LBB0_4
xor eax, eax
  <scalar version>
.LBB0_4:
  <vector version>`

若发现这种变量别名的时候就会走正常的版本,确保计算结果的正确性,但这也引入了检查的开销。

当然如果你确定你的程序不会出现这种情况,可以通过__restrict__关键字取消检查(link to godbolt

最后

其实大部分时候编译器都会帮助我们自动的做向量化的优化,我们是不需要自己手动使用SIMD指令去写代码的。不过了解一点相关知识还是有助于写出性能更高的代码的。