Follow Excellent, Success will Chase you

0%

ARMv8中的SIMD运算

NEON是一种压缩的SIMD架构,主要是给多媒体使用,结果并行计算的问题。

NEON是ARMv7-A和ARMv7-R引入的特性,在后面的ARMv8-A和ARMv8-R中也扩展其功能.1288bit的向量运算

ARMv7-A/RARMv8-A/RARMv8-A
AArch32AArch64
Floating-point32-bit16-bit*/32-bit16-bit*/32-bit
Integer8-bit/16-bit/32-bit8-bit/16-bit/32-bit/64-bit8-bit/16-bit/32-bit/64-bit

ARMv8与ARMv7的区别

  • 1.与通用寄存器相同的助记符
CPU通用SIMD
ARMv7mul, r0, r0, r1vmul d0, d0, d1
ARMv8mul x0, x0, x1mul v0.u8, v0.u8, v1.u8

注意:在ARMv7中所有的SIMD汇编的操作码如mul的前缀都有v如vml

  • 2.ARMv8的寄存器是ARMv7的两倍

    • ARMv8拥有32个128-bit寄存器
    • ARMv7拥有16个128-bit寄存器
  • 3.不同的指令语法

SIMD寄存器

armv8SIMD寄存器

寄存器个数位宽数据类型
D寄存器(D0-D3132个64-bit双字(double word)
Q寄存器(Q0-Q1516个128-bit四字

矢量寄存器V0-V31:包装

armv8SIMD寄存器标识vx

打包V0-V31中的数据,方便数据操作

ARMv8SIMD寄存器打包

矢量包装

ARMvc8

主要定义每一个矢量Vn的数据位宽

标识位宽数据类型示例
b8bitcharv0.8b,v0.16b: 8个bit16个bit
h16bitshortv0.4h,v0.8h: 4或8个半字(short类型)
s32bitintv0.2s,v0.4s:2或4个字
d64bitlong longv0.2d:2个double word

指令语法

ARMv8SIMD指令op

1
ld4 {v0.4h-v3.4h}, [%0]

等同于:

1
ld4 {v0.4h, v1.4h, v2.4h, v3.4h}, [%0]

内联函数编程

NEON 内在函数在头文件arm_neon.h中定义。头文件既定义内在函数,也定义一组向量类型

NEON操作函数

内嵌汇编编程

1
2
3
4
5
6
7
8
9
asm volatile(
"mnemonic+operand \n\t"
"mnemonic+operand \n\t"
"mnemonic+operand \n\t"

: //output operand list /*输出操作数列表*/
: //input operand list /*输入操作数列表*/
: //Dirty registers etc /*被改变资源列表*/
);

操作符&修饰符

1
2
3
4
5
6
asm volatile(
"add %0, %1, %2"

: "=r" (ret)
: "r" (a), "r" (b)
);
操作符含义
r通用寄存器
m一个有效的内存地址
I数据处理中的立即数
X被修饰的操作符只能作为输出
修饰符含义
只读
=只写
+可读可写
&只能作为输出

传参

参数序列

1
2
3
4
5
6
asm volatile(
"add %0, %1, %2"

: "=r" (ret)
: "r" (a), "r" (b)
);
  • ret: %0, 第一个参数
  • a : %1, 第二个参数
  • b : %2, 第三个参数

参数名

1
2
3
4
5
6
asm volatile(
"add %[result], %[a], %[b]"

: [result] "=r" (ret)
: [a] "r" (a), [b] "r" (b)
);

传入参数不依赖参数序列

示例

4x4矩阵乘法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <sys/time.h>

#if __aarch64__
#include <arm_neon.h>
#endif

static void dump(uint16_t **x)
{
int i, j;
uint16_t *xx = (uint16_t *)x;

printf("%s:\n", __func__);

for(i = 0; i < 4; i++) {
for(j = 0; j < 4; j++) {
printf("%3d ", *(xx + (i << 2) + j));
}

printf("\n");
}
}

static void matrix_mul_c(uint16_t aa[][4], uint16_t bb[][4], uint16_t cc[][4])
{
int i = 0, j = 0;

printf("===> func: %s, line: %d\n", __func__, __LINE__);

for(i = 0; i < 4; i++) {
for(j = 0; j < 4; j++) {
cc[i][j] = aa[i][j] * bb[i][j];
}
}

}

#if __aarch64__
static void matrix_mul_neon(uint16_t **aa, uint16_t **bb, uint16_t **cc)
{
printf("===> func: %s, line: %d\n", __func__, __LINE__);
#if 1
uint16_t (*a)[4] = (uint16_t (*)[4])aa;
uint16_t (*b)[4] = (uint16_t (*)[4])bb;
uint16_t (*c)[4] = (uint16_t (*)[4])cc;

printf("aaaaaaaa\n");
asm("nop");
asm("nop");
asm("nop");
asm("nop");
uint16x4_t _cc0;
uint16x4_t _cc1;
uint16x4_t _cc2;
uint16x4_t _cc3;

uint16x4_t _aa0 = vld1_u16((uint16_t*)a[0]);
uint16x4_t _aa1 = vld1_u16((uint16_t*)a[1]);
uint16x4_t _aa2 = vld1_u16((uint16_t*)a[2]);
uint16x4_t _aa3 = vld1_u16((uint16_t*)a[3]);

uint16x4_t _bb0 = vld1_u16((uint16_t*)b[0]);
uint16x4_t _bb1 = vld1_u16((uint16_t*)b[1]);
uint16x4_t _bb2 = vld1_u16((uint16_t*)b[2]);
uint16x4_t _bb3 = vld1_u16((uint16_t*)b[3]);

_cc0 = vmul_u16(_aa0, _bb0);
_cc1 = vmul_u16(_aa1, _bb1);
_cc2 = vmul_u16(_aa2, _bb2);
_cc3 = vmul_u16(_aa3, _bb3);

vst1_u16((uint16_t*)c[0], _cc0);
vst1_u16((uint16_t*)c[1], _cc1);
vst1_u16((uint16_t*)c[2], _cc2);
vst1_u16((uint16_t*)c[3], _cc3);
asm("nop");
asm("nop");
asm("nop");
asm("nop");
#else
printf("bbbbbbbb\n");
int i = 0;
uint16x4_t _aa[4], _bb[4], _cc[4];
uint16_t *a = (uint16_t*)aa;
uint16_t *b = (uint16_t*)bb;
uint16_t *c = (uint16_t*)cc;

for(i = 0; i < 4; i++) {
_aa[i] = vld1_u16(a + (i << 2));
_bb[i] = vld1_u16(b + (i << 2));
_cc[i] = vmul_u16(_aa[i], _bb[i]);
vst1_u16(c + (i << 2), _cc[i]);
}

#endif
}

static void matrix_mul_asm(uint16_t **aa, uint16_t **bb, uint16_t **cc)
{
printf("===> func: %s, line: %d\n", __func__, __LINE__);

uint16_t *a = (uint16_t*)aa;
uint16_t *b = (uint16_t*)bb;
uint16_t *c = (uint16_t*)cc;

#if 0
asm volatile(
"ldr d3, [%0, #0] \n\t"
"ldr d2, [%0, #8] \n\t"
"ldr d1, [%0, #16] \n\t"
"ldr d0, [%0, #24] \n\t"

"ldr d7, [%1, #0] \n\t"
"ldr d6, [%1, #8] \n\t"
"ldr d5, [%1, #16] \n\t"
"ldr d4, [%1, #24] \n\t"

"mul v3.4h, v3.4h, v7.4h \n\t"
"mul v2.4h, v2.4h, v6.4h \n\t"
"mul v1.4h, v1.4h, v5.4h \n\t"
"mul v0.4h, v0.4h, v4.4h \n\t"

//"add v3.4h, v3.4h, v7.4h \n\t"
//"add v2.4h, v2.4h, v6.4h \n\t"
//"add v1.4h, v1.4h, v5.4h \n\t"
//"add v0.4h, v0.4h, v4.4h \n\t"

"str d3, [%2,#0] \n\t"
"str d2, [%2,#8] \n\t"
"str d1, [%2,#16] \n\t"
"str d0, [%2,#24] \n\t"

: "+r"(a), //%0
"+r"(b), //%1
"+r"(c) //%2
:
: "cc", "memory", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7"
);
#else
// test, OK
asm("nop");
asm("nop");
asm("nop");
asm("nop");
asm("nop");
asm volatile(
//"ld4 {v0.4h, v1.4h, v2.4h, v3.4h}, [%0] \n\t"
"ld4 {v0.4h-v3.4h}, [%0] \n\t"
"ld4 {v4.4h, v5.4h, v6.4h, v7.4h}, [%1] \n\t"

"mul v3.4h, v3.4h, v7.4h \n\t"
"mul v2.4h, v2.4h, v6.4h \n\t"
"mul v1.4h, v1.4h, v5.4h \n\t"
"mul v0.4h, v0.4h, v4.4h \n\t"

"st4 {v0.4h, v1.4h, v2.4h, v3.4h}, [%2] \n\t"

: "+r"(a), //%0
"+r"(b), //%1
"+r"(c) //%2
:
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
);
asm("nop");
asm("nop");
asm("nop");
asm("nop");
asm("nop");
#endif
}
#endif

int main(int argc, const char *argv[])
{
uint16_t aa[4][4] = {
{1, 2, 3, 4},
{5, 6, 7, 8},
{3, 6, 8, 1},
{2, 6, 7, 1}
};

uint16_t bb[4][4] = {
{1, 3, 5, 7},
{2, 4, 6, 8},
{2, 5, 7, 9},
{5, 2, 7, 1}
};

uint16_t cc[4][4] = {0};
int i, j;
struct timeval tv;
long long start_us = 0, end_us = 0;

dump((uint16_t **)aa);
dump((uint16_t **)bb);
dump((uint16_t **)cc);

/* ******** C **********/
gettimeofday(&tv, NULL);
start_us = tv.tv_sec + tv.tv_usec;

matrix_mul_c(aa, bb, cc);

gettimeofday(&tv, NULL);
end_us = tv.tv_sec + tv.tv_usec;
printf("aa[][]*bb[][] C time %lld us\n", end_us - start_us);
dump((uint16_t **)cc);

#if __aarch64__
/* ******** NEON **********/
memset(cc, 0, sizeof(uint16_t) * 4 * 4);
gettimeofday(&tv, NULL);
start_us = tv.tv_sec + tv.tv_usec;

matrix_mul_neon((uint16_t **)aa, (uint16_t **)bb, (uint16_t **)cc);

gettimeofday(&tv, NULL);
end_us = tv.tv_sec + tv.tv_usec;
printf("aa[][]*bb[][] neon time %lld us\n", end_us - start_us);
dump((uint16_t **)cc);

/* ******** asm **********/
memset(cc, 0, sizeof(uint16_t) * 4 * 4);
gettimeofday(&tv, NULL);
start_us = tv.tv_sec + tv.tv_usec;

matrix_mul_asm((uint16_t **)aa, (uint16_t **)bb, (uint16_t **)cc);

gettimeofday(&tv, NULL);
end_us = tv.tv_sec + tv.tv_usec;
printf("aa[][]*bb[][] asm time %lld us\n", end_us - start_us);
dump((uint16_t **)cc);
#endif

return 0;
}
1
aarch64-linux-gcc -O3  matrix_4x4_mul.c

gcc –march=armv8-a [input file] -o [output file]

8x8矩阵乘法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
static void matrix_mul_asm(uint16_t **aa, uint16_t **bb, uint16_t **cc)
{
printf("===> func: %s, line: %d\n", __func__, __LINE__);

uint16_t *a = (uint16_t*)aa;
uint16_t *b = (uint16_t*)bb;
uint16_t *c = (uint16_t*)cc;

asm volatile(
"ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n\t"
"ld4 {v8.8h, v9.8h, v10.8h, v11.8h}, [%1] \n\t"

"mul v0.8h, v0.8h, v8.8h \n\t"
"mul v1.8h, v1.8h, v9.8h \n\t"
"mul v2.8h, v2.8h, v10.8h \n\t"
"mul v3.8h, v3.8h, v11.8h \n\t"

"st4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%2] \n\t"


"add x1, %0, #64 \n\t"
"add x2, %1, #64 \n\t"
"add x3, %2, #64 \n\t"

//"ld4 {v4.8h-v7.8h}, [x1] \n\t"
"ld4 {v4.8h, v5.8h, v6.8h, v7.8h}, [x1] \n\t"
"ld4 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2] \n\t"

"mul v4.8h, v4.8h, v12.8h \n\t"
"mul v5.8h, v5.8h, v13.8h \n\t"
"mul v6.8h, v6.8h, v14.8h \n\t"
"mul v7.8h, v7.8h, v15.8h \n\t"

"st4 {v4.8h, v5.8h, v6.8h, v7.8h}, [x3] \n\t"

: "+r"(a), //%0
"+r"(b), //%1
"+r"(c) //%2
:
: "cc", "memory", "x1", "x2", "x3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
);
}

内嵌汇编实现方式8x8

参考

-------------本文结束感谢您的阅读-------------
  • 本文作者: Winddoing
  • 本文链接: https://winddoing.github.io/post/13631.html
  • 作者声明: 本博文为个人笔记, 由于个人能力有限,难免出现错误,欢迎大家批评指正。
  • 版权声明: 本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!