Sve算子解释
这里记录了关于为ggml_gemv_q4_0_4x8_q8_0函数添加sve计算方式的过程
sve 算子
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
//only 128bits support
if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == 16) {
//这里准备一个用于调换数组1,2号元素顺序的mask,以及8bit,16bit,32bit的mask
uint32_t indices[] = {0, 2, 1, 3};
svuint32_t index_vec = svld1_u32(svptrue_b32(), indices);
const svbool_t pg_b8 = svptrue_b8();
const svbool_t pg_b16 = svptrue_b16();
const svbool_t pg_b32 = svptrue_b32();
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
for (int c = 0; c < nc; c += ncols_interleaved) {
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
svfloat32_t acc= svdup_n_f32(0.0f);
for (int b = 0; b < nb; b++) {
//与neon相同的取数过程,利用指针偏移和svld1,为128bit寄存器填充16个int8_t
const svint8_t b0 = svld1_s8(pg_b8,(const int8_t *) b_ptr->qs);
const svint8_t b1 = svld1_s8(pg_b8,(const int8_t *) b_ptr->qs + 16);
const svint8_t b2 = svld1_s8(pg_b8,(const int8_t *) b_ptr->qs + 32);
const svint8_t b3 = svld1_s8(pg_b8,(const int8_t *) b_ptr->qs + 48);
svfloat16_t bd = svld1_f16(pg_b16, (const __fp16 *)b_ptr->d);//取出f16共8个,但只有前四个有用
svfloat16_t ad = svdup_n_f16(*(const __fp16 *) &a_ptr->d);//取出f16复制8次
svfloat32_t scale = svmul_f32_x(pg_b32,svcvt_f32_f16_x(pg_b16, ad),svcvt_f32_f16_x(pg_b16, svzip1(bd, bd)));
//先通过zip把bd的数据每个复制2次,[A,B,C,D]=>[A,A,B,B,C,C,D,D],这样寄存器中只保留前四个数据,再和ad的数据都转化为f32,四个f32相乘得到缩放系数,相对于原有的4个f16相乘,这里多算了
//与neon相同的取数过程,利用指针偏移和svdup,为128bit寄存器填充2个64bit数,每个64bit中包含8个int8,2个64bit相同
const svint8_t a0 =svreinterpret_s8_u64(svdup_n_u64(*((const uint64_t *)a_ptr->qs)));
const svint8_t a1 =svreinterpret_s8_u64(svdup_n_u64(*((const uint64_t *)a_ptr->qs+1)));
const svint8_t a2 =svreinterpret_s8_u64(svdup_n_u64(*((const uint64_t *)a_ptr->qs+2)));
const svint8_t a3 =svreinterpret_s8_u64(svdup_n_u64(*((const uint64_t *)a_ptr->qs+3)));
svint32_t ret0 = svdup_n_s32(0);
svint32_t ret1 = svdup_n_s32(0);
//与neon相同的dot指令,每四个对应的int8数据相乘加到一个int32中
ret0 = svdot_s32(ret0, b0 << 4, a0);
ret1 = svdot_s32(ret1, b1 << 4, a0);
ret0 = svdot_s32(ret0, b2 << 4, a1);
ret1 = svdot_s32(ret1, b3 << 4, a1);
ret0 = svdot_s32(ret0, b0 & 0xf0U, a2);
ret1 = svdot_s32(ret1, b1 & 0xf0U, a2);
ret0 = svdot_s32(ret0, b2 & 0xf0U, a3);
ret1 = svdot_s32(ret1, b3 & 0xf0U, a3);
//由于sve没有neon的水平相加指令,即[A,B,C,D][E,F,G,H]=>[A+B,C+D,E+F,G+H];
//sve2的成对相加是[A,B,C,D][E,F,G,H]=>[A+B,E+F,C+D,G+H];所以需要sve调换顺序。
// uint32_t indices[] = {0, 2, 1, 3};
// svuint32_t index_vec = svld1_u32(svptrue_b32(), indices);
svint32_t ret = svtbl_s32(svaddp_s32_m(pg_b32, ret0, ret1),index_vec);
//这里还记录了2种只使用sve命令的实现水平相加效果的命令
// svint32_t ret = svadd_s32_m(svptrue_b32(), svuzp1_s32(ret0, ret1), svuzp2_s32(ret0, ret1));
// svint32_t ret = svadd_s32_m(svptrue_b32(), svzip1_s32(ret0, ret1), svzip2_s32(ret0, ret1));
//与neon相同的mla过程,将ret结果和缩放系数相乘,还原为f32
acc = svmla_f32_x(pg_b32,acc,svcvt_f32_s32_x(pg_b32,ret>>4),scale);
a_ptr++;
b_ptr++;
}
//与neon相同的过程,将数据存储回指针中
svst1_f32(pg_b32,s, acc);
s += ncols_interleaved;
}
return;
}
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE2)
原neon算子
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
for (int c = 0; c < nc; c += ncols_interleaved) {
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
float32x4_t acc = vdupq_n_f32(0);
for (int b = 0; b < nb; b++) {
//取数过程
int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);//取出4个f16
int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);//取出1个f16,复制4次
int32x4_t ret0 = vdupq_n_s32(0);
int32x4_t ret1 = vdupq_n_s32(0);
//计算过程
ret0 = vdotq_s32(ret0, b0 << 4, a0);
ret1 = vdotq_s32(ret1, b1 << 4, a0);
ret0 = vdotq_s32(ret0, b2 << 4, a1);
ret1 = vdotq_s32(ret1, b3 << 4, a1);
ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);
//水平相加
int32x4_t ret = vpaddq_s32(ret0, ret1);
//利用scale将int32转换为f32
acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));//这里是先计算了4个f16相乘,再计算和int32数据的相乘
a_ptr++;
b_ptr++;
}
//储存过程
vst1q_f32(s, acc);
s += ncols_interleaved;
}
return;
}
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)