上期讲到了Clickhoue性能高的原因是应用了向量化技术,也就是SIMD(Single Instruction Multiple Data,单指令多数据流)的指令来一次处理多个数据。这么讲有点抽象,今天就用几个例子感受一下SIMD具体是怎么运作的。
本文例子来自dendibakh的博客系列,一个非常好的介绍SIMD的系列,十分推荐。
先来看个简单的例子
#include <vector>
void foo( std::vector<unsigned>& lhs, std::vector<unsigned>& rhs )
{
for( unsigned i = 0; i < lhs.size(); i++ )
{
lhs[i] = ( rhs[i] + 1 ) >> 1;
}
}
我们用clang加上 -O2 -march=core-avx2 -std=c++14 -fno-unroll-loops
的编译选项来编译上面这段代码,-fno-unroll-loops
告诉编译器禁止做循环展开的优化,-march=core-avx2
提示编译器使用 avx2
指令集来做向量话。编译生成的汇编代码可以直观的在Compiler Explorer看到结果。
我们可以看到,循环内的代码实际上生成来两个版本,一个是正常版
mov edx, dword ptr [r9 + 4*rsi] # loading rhs[i]
add edx, 1 # rhs[i] + 1
shr edx # (rhs[i] + 1) >> 1
mov dword ptr [rax + 4*rsi], edx # store result to lhs
mov esi, edi
add edi, 1 # incrementing i by 1
cmp rcx, rsi
ja <next iteration>
一个是使用了AVX2指令的向量化版本
vmovdqu ymm1, ymmword ptr [r9 + 4*rdi] # loading 256 bits from rhs
vpsubd ymm1, ymm1, ymm0 # ymm0 has all bits set, like +1
vpsrld ymm1, ymm1, 1 # vector shift right.
vmovdqu ymmword ptr [rax + 4*rdi], ymm1 # storing result
add rdi, 8 # incrementing i by 8
cmp rsi, rdi
jne <next iteration>
在进入循环的时候会先判断lhs
未处理的元素的大小是否大于8,若大于8个,则使用向量化的版本一次循环处理8个数字,若到后面小于8个待处理的元素,则使用正常的版本一个个数字处理。
其实大部分的汇编代码模式都是类似的,遵循 加载数据->处理数据->写回数据 这样的模式。对比两个版本的代码也可以看出,向量版的也是这样的模式,只是指令变成了可以一次操作多个数的版本, mov
变成了vmovdqu
, shr
变成了vpsrld
之类的。寄存器使用的也是位数更多的寄存器(256位的YMM寄存器或128位的XMM寄存器)。
下面我们再来详细看一个复杂点的例子。(godbolt)
int foo( short a[16], unsigned short b[16], unsigned short bias[16] )
{
int agg = 0;
for( int i = 0; i < 16; i++ )
{
if( a[i] > 0 )
a[i] = (bias[i] + a[i]) * b[i] >> 16;
else
a[i] = - ((bias[i] - a[i]) * b[i] >> 16);
agg += a[i];
}
return agg;
}
我们一段一段来看:
vmovdqu ymm4,YMMWORD PTR [rdi] # a[]
vmovdqu ymm2,YMMWORD PTR [rdx] # bias[]
vmovdqu ymm5,YMMWORD PTR [rsi] # b[]
vmovdqa ymm6,YMMWORD PTR [rip+0x5c6] # 400b00
上面的指令只是单纯的加载数据到YMM寄存器里面,执行之后结果如下
由于16位数的相加可能会移除,所以我们要把16位的整数扩展成32位的整数。
vpmovzxwd ymm1,xmm2
vextracti128 xmm8,ymm4,0x1
vextracti128 xmm2,ymm2,0x1
vpmovzxwd ymm9,xmm5
vpmovzxwd ymm2,xmm2
vpmovsxwd ymm0,xmm4
vpmovsxwd ymm8,xmm8
vextracti128 xmm5,ymm5,0x1
上面一顿操作之后,寄存器里面就变成了
当数据都准备好后,就可以用vpaddd
, vpsubd
, vpmuld
等指令批量做运算了
vpaddd ymm3,ymm1,ymm0
vpmovzxwd ymm5,xmm5
vpsubd ymm0,ymm1,ymm0
vpsubd ymm1,ymm2,ymm8
vpmulld ymm0,ymm0,ymm9
vpmulld ymm1,ymm1,ymm5
vpaddd ymm7,ymm2,ymm8
vpmulld ymm3,ymm3,ymm9
vpmulld ymm7,ymm7,ymm5
结果如下
注意,我们的循环代码里面有个if分支,那向量优化是如何做遇上if又是怎么做的呢?从上图的结果来看,十分的暴力,直接就都算了两个分支的结果,(bias[i] + a[i]) * b[i]
的结果用result of +
表示,保存在ymm3和ymm7。(bias[i] - a[i]) * b[i]
的结果用result of - *
表示,保存在ymm0和ymm1。
后面还要再接一个位移操作
vpsrad ymm0,ymm0,0x10
vpsrad ymm1,ymm1,0x10
vpand ymm1,ymm6,ymm1
vpsrad ymm3,ymm3,0x10
vpand ymm0,ymm6,ymm0
vpackusdw ymm0,ymm0,ymm1
vpsrad ymm7,ymm7,0x10
vpxor xmm1,xmm1,xmm1
把32位的结果打包变回16位整数。
vpcmpgtw ymm4,ymm4,ymm1
vpand ymm3,ymm6,ymm3
vpand ymm7,ymm6,ymm7
vpackusdw ymm3,ymm3,ymm7
指令vpcmpgtw ymm4,ymm4,ymm1
一次性比较ymm4和ymm1中每个16位整数的大小,结果保存回ymm4里面,也就是代码里面的if( a[i] > 0 )
判断。
vpermq ymm0,ymm0,0xd8
vpermq ymm3,ymm3,0xd8
vpsubw ymm0,ymm1,ymm0
vpblendvb ymm0,ymm0,ymm3,ymm4
vpmovsxwd ymm1,xmm0
vmovdqu YMMWORD PTR [rdi],ymm0
vpermq
用于整理数据,vpsubw ymm0,ymm1,ymm0
把ymm0的数据取负数。
最关键的是vpblendvb ymm0,ymm0,ymm3,ymm4
,这条指令会根据ymm3的标示,从ymm0或者ymm3取数据到ymm0,由此就实现了if的功能。
后面还有累加ymm0的结果到变量agg
的代码就忽略了。
有时候数据之间会有依赖关系,不能简单的并行化,就比如
void foo( unsigned short * a, unsigned short * b )
{
for( int i = 0; i < 128; i++ )
{
a[i] += b[i];
}
}
我们调用的foo
函数的时候有
unsigned short x[] = {1, 1, 1, 1, ... , 1}; // 129 elements
unsigned short* a = x + 1;
unsgined short* b = x;
foo (a, b);
正常的计算由于a和b是指向的是同一个数组,会得出结果 x = {1, 2, 3, 4, 5, ... }
,但若是向量化的话,就会得出x = (2, 2, 2, 2, ...)
的结果。
但实际上我们的计算结果是正确的,那是因为编译器在会为在计算前为我们做个变量的别名检查(godbolt)
lea rax, [rsi + 256] # calculating the end of b (b + 128)
cmp rax, rdi # comparing the beginning of a and the end of b
jbe .LBB0_4
lea rax, [rdi + 256] # calculating the end of a (a + 128)
cmp rax, rsi # comparing the beginning of b and the end of a
jbe .LBB0_4
xor eax, eax
<scalar version>
.LBB0_4:
<vector version>`
若发现这种变量别名的时候就会走正常的版本,确保计算结果的正确性,但这也引入了检查的开销。
当然如果你确定你的程序不会出现这种情况,可以通过__restrict__
关键字取消检查(link to godbolt)
其实大部分时候编译器都会帮助我们自动的做向量化的优化,我们是不需要自己手动使用SIMD指令去写代码的。不过了解一点相关知识还是有助于写出性能更高的代码的。