1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #include "./FbgemmFP16UKernelsAvx2.h" |
8 | #include "./InlineAsmDefines.h" |
9 | |
10 | namespace fbgemm { |
11 | |
// 1x2 fp16 GEMM micro-kernel (AVX2 + F16C + FMA).
// For each of gp->b_block_cols column blocks of B it computes one row,
// two ymm-wide (2 x 8 fp32) panels of C:
//   C[0:16] = A_row * B_panel + beta * C[0:16]
// B elements are fp16 and converted to fp32 on the fly (vcvtph2ps);
// A and C are fp32.  When beta == 0.0f, C is written without being read.
// Register roles after the parameter-copy prologue:
//   r8  = k - 1 (one k-step is peeled into the setup code)
//   r9  = A, r10 = B, r15 = &gp->beta, r12 = C, r13 = ldc (bytes)
//   rdi = b_block_cols, rsi = b_block_size (loaded but not used below)
//   rax = saved A, rcx = saved C, rbx = outer-loop (block-column) index
//   r14 = remaining inner (k) iterations for the current block column
void NOINLINE gemmkernel_1x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  asm volatile(
#if FBGEMM_USE_CLANG_INTEL_SYNTAX_ASM_HACK
      // Clang workaround: move gp while still in AT&T mode, then switch
      // the assembler to Intel syntax for the remainder of the block.
      "mov %[gp], %%r14\t\n"
      ".intel_syntax noprefix\t\n"
#else
      "mov r14, %[gp]\t\n"
#endif

      // Copy parameters
      // k
      "mov r8, [r14 + 0]\t\n"
      // r8 = k - 1: the first k-step is handled in the beta/zero setup.
      "dec r8\t\n"
      // A
      "mov r9, [r14 + 8]\t\n"
      // B
      "mov r10, [r14 + 16]\t\n"
      // beta
      "lea r15, [r14 + 24]\t\n"
      // C
      "mov r12, [r14 + 32]\t\n"
      // ldc
      "mov r13, [r14 + 40]\t\n"
      // b_block_cols
      "mov rdi, [r14 + 48]\t\n"
      // b_block_size
      "mov rsi, [r14 + 56]\t\n"

      // Make copies of A and C
      "mov rax, r9\t\n"
      "mov rcx, r12\t\n"

      "xor ebx, ebx\t\n"
      // Outer loop: one iteration per 16-column block of B.
      "loop_outter%=:\t\n"
      "mov r14, r8\t\n"

      // Clear the secondary accumulators used by the 4x-unrolled inner
      // loop; they are folded back into ymm0/ymm1 at loop_tail.
      "vxorps ymm5, ymm5, ymm5\t\n"
      "vxorps ymm6, ymm6, ymm6\t\n"
      "vxorps ymm7, ymm7, ymm7\t\n"
      "vxorps ymm8, ymm8, ymm8\t\n"
      "vxorps ymm9, ymm9, ymm9\t\n"
      "vxorps ymm10, ymm10, ymm10\t\n"

      // ymm15 = broadcast(beta); ymm3/ymm4 = first B panel (fp16->fp32).
      "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
      "vcvtph2ps ymm3,XMMWORD PTR [r10 + 0]\t\n"
      "vcvtph2ps ymm4,XMMWORD PTR [r10 + 16]\t\n"
      // Compare beta against 0.0f; ZF is set when equal, so the branch
      // below takes the path that never reads C.
      "vxorps xmm0, xmm0, xmm0\t\n"
      "vcomiss xmm15, xmm0\t\n"
      "jz zero_regs%=\t\n"

      // Setup values with beta multiplication
      // First (peeled) k-step: acc = beta * C + A[0] * B_panel.
      "vmulps ymm0, ymm15, [r12 + 0]\t\n"
      "vmulps ymm1, ymm15, [r12 + 32]\t\n"
      "test r14,r14\t\n"
      "jz skip_preload%=\t\n"
      // Preload the next B panel into ymm15 (only if k-steps remain;
      // beta in ymm15 is no longer needed past this point).
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "skip_preload%=:\t\n"
      "vbroadcastss ymm2,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm3,ymm2\t\n"
      "vfmadd231ps ymm1,ymm4,ymm2\t\n"
      // k == 1: skip the inner loop entirely and write C out.
      "test r14,r14\t\n"
      "jnz loop_inner_start%=\t\n"
      "add r10,32\t\n"
      "jmp dump_C%=\t\n"

      // beta == 0 path: identical first k-step, but plain multiplies so
      // C is never read.
      "zero_regs%=:\t\n"

      "test r14,r14\t\n"
      "jz skip_preload_b_zero%=\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "skip_preload_b_zero%=:\t\n"
      "vbroadcastss ymm2,DWORD PTR [r9+0]\t\n"
      "vmulps ymm0,ymm3,ymm2\t\n"
      "vmulps ymm1,ymm4,ymm2\t\n"
      "test r14,r14\t\n"
      "jnz loop_inner_start%=\t\n"
      "add r10,32\t\n"
      "jmp dump_C%=\t\n"

      // Advance past the peeled k-step; take the unrolled loop only
      // when more than 4 iterations remain.
      "loop_inner_start%=:\t\n"
      "add r9,4\t\n"
      "add r10,32\t\n"
      "cmp r14,4\t\n"

      "jle loop_inner_end%=\t\n"

      // Inner loop unrolled 4x over k.  B panels rotate through
      // ymm3/ymm4/ymm15 so each 16-byte fp16 load is converted once and
      // reused by the following step; partial sums spread across
      // ymm0/1, ymm5/6, ymm7/8, ymm9/10 to hide FMA latency.
      "loop_inner%=:\t\n"

      "vcvtph2ps ymm4,XMMWORD PTR [r10 + 16]\t\n"
      "vcvtph2ps ymm3,XMMWORD PTR [r10 + 32]\t\n"
      "vbroadcastss ymm2,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm15,ymm2\t\n"
      "vfmadd231ps ymm1,ymm4,ymm2\t\n"

      "vcvtph2ps ymm4,XMMWORD PTR [r10 + 48]\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 64]\t\n"
      "vbroadcastss ymm2,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm5,ymm3,ymm2\t\n"
      "vfmadd231ps ymm6,ymm4,ymm2\t\n"

      "vcvtph2ps ymm4,XMMWORD PTR [r10 + 80]\t\n"
      "vcvtph2ps ymm3,XMMWORD PTR [r10 + 96]\t\n"
      "vbroadcastss ymm2,DWORD PTR [r9+8]\t\n"
      "vfmadd231ps ymm7,ymm15,ymm2\t\n"
      "vfmadd231ps ymm8,ymm4,ymm2\t\n"

      "vcvtph2ps ymm4,XMMWORD PTR [r10 + 112]\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 128]\t\n"
      "vbroadcastss ymm2,DWORD PTR [r9+12]\t\n"
      "vfmadd231ps ymm9,ymm3,ymm2\t\n"
      "vfmadd231ps ymm10,ymm4,ymm2\t\n"

      // 4 k-steps consumed: A advances 4 floats, B advances 4 panels.
      "next_inner%=:\t\n"
      "add r9,16\t\n"
      "add r10,128\t\n"
      "sub r14,4\t\n"

      "cmp r14, 4\t\n"
      "jg loop_inner%=\t\n"
      "loop_inner_end%=:\t\n"

      // Drain the remaining 1..4 k-steps one at a time.
      "cmp r14, 0\t\n"
      "jz loop_tail%=\t\n"

      "vcvtph2ps ymm3,XMMWORD PTR [r10]\t\n"
      "vcvtph2ps ymm4,XMMWORD PTR [r10 + 16]\t\n"
      "vbroadcastss ymm2,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm3,ymm2\t\n"
      "vfmadd231ps ymm1,ymm4,ymm2\t\n"
      "add r9,4\t\n"
      "add r10,32\t\n"
      "dec r14\t\n"

      "jmp loop_inner_end%=\t\n"

      // Fold the partial accumulators into ymm0/ymm1.
      "loop_tail%=:\t\n"
      "vaddps ymm0, ymm0, ymm5\t\n"
      "vaddps ymm0, ymm0, ymm7\t\n"
      "vaddps ymm0, ymm0, ymm9\t\n"
      "vaddps ymm1, ymm1, ymm6\t\n"
      "vaddps ymm1, ymm1, ymm8\t\n"
      "vaddps ymm1, ymm1, ymm10\t\n"

      // Dump C
      "dump_C%=:\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm1\t\n"

      // next outer iteration
      // C advances 64 bytes (16 fp32 columns); A rewinds to its start.
      "add rcx, 64\t\n"
      "mov r12, rcx\t\n"
      "mov r9, rax\t\n"
      "inc rbx\t\n"
      "cmp rbx, rdi\t\n"
      "jl loop_outter%=\t\n"
      :
      : [gp] "rm" (gp)
      : "r8" ,
        "r9" ,
        "r10" ,
        "r11" ,
        "r13" ,
        "r14" ,
        "rax" ,
        "rcx" ,
        "rsi" ,
        "rdi" ,
        "rbx" ,
        "r12" ,
        "r15" ,
        "memory" );
}
// 2x2 fp16 GEMM micro-kernel (AVX2 + F16C + FMA).
// Same structure as the 1x2 kernel but computes two rows of C per block
// column of B (accumulators ymm0/ymm1 = row 0, ymm2/ymm3 = row 1):
//   C[r, 0:16] = A[r, :] * B_panel + beta * C[r, 0:16], r = 0..1
// B is fp16 (converted via vcvtph2ps); A and C are fp32.  The inner
// loop is unrolled 4x, with steps 2 and 4 accumulating into the
// secondary registers ymm7..ymm10 that are folded back at loop_tail.
// Register roles: r8 = k-1, r9 = A, r10 = B, r15 = &beta, r12 = C,
// r13 = ldc (bytes), rdi = b_block_cols, rsi = b_block_size (unused
// below), rax/rcx = saved A/C, rbx = block-column index, r14 =
// remaining k iterations.
void NOINLINE gemmkernel_2x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  asm volatile(
#if FBGEMM_USE_CLANG_INTEL_SYNTAX_ASM_HACK
      // Clang workaround: move gp in AT&T mode, then switch to Intel
      // syntax for the rest of the block.
      "mov %[gp], %%r14\t\n"
      ".intel_syntax noprefix\t\n"
#else
      "mov r14, %[gp]\t\n"
#endif

      // Copy parameters
      // k
      "mov r8, [r14 + 0]\t\n"
      // r8 = k - 1: the first k-step is peeled into the setup code.
      "dec r8\t\n"
      // A
      "mov r9, [r14 + 8]\t\n"
      // B
      "mov r10, [r14 + 16]\t\n"
      // beta
      "lea r15, [r14 + 24]\t\n"
      // C
      "mov r12, [r14 + 32]\t\n"
      // ldc
      "mov r13, [r14 + 40]\t\n"
      // b_block_cols
      "mov rdi, [r14 + 48]\t\n"
      // b_block_size
      "mov rsi, [r14 + 56]\t\n"

      // Make copies of A and C
      "mov rax, r9\t\n"
      "mov rcx, r12\t\n"

      "xor ebx, ebx\t\n"
      // Outer loop: one iteration per 16-column block of B.
      "loop_outter%=:\t\n"
      "mov r14, r8\t\n"

      // Clear the secondary accumulators used by the unrolled loop.
      "vxorps ymm7, ymm7, ymm7\t\n"
      "vxorps ymm8, ymm8, ymm8\t\n"
      "vxorps ymm9, ymm9, ymm9\t\n"
      "vxorps ymm10, ymm10, ymm10\t\n"

      // ymm15 = broadcast(beta); ymm5/ymm6 = first B panel (fp16->fp32).
      "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
      "vcvtph2ps ymm5,XMMWORD PTR [r10 + 0]\t\n"
      "vcvtph2ps ymm6,XMMWORD PTR [r10 + 16]\t\n"
      // Branch to the no-read path when beta == 0.0f (ZF set on equal).
      "vxorps xmm0, xmm0, xmm0\t\n"
      "vcomiss xmm15, xmm0\t\n"
      "jz zero_regs%=\t\n"

      // Setup values with beta multiplication
      // Peeled first k-step: acc = beta * C + A[r] * B_panel for both
      // rows; r12 steps to the next C row by ldc in between.
      "vmulps ymm0, ymm15, [r12 + 0]\t\n"
      "vmulps ymm1, ymm15, [r12 + 32]\t\n"
      "add r12, r13\t\n"
      "vmulps ymm2, ymm15, [r12 + 0]\t\n"
      "vmulps ymm3, ymm15, [r12 + 32]\t\n"
      "test r14,r14\t\n"
      "jz skip_preload%=\t\n"
      // Preload next B panel; beta in ymm15 is no longer needed.
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "skip_preload%=:\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm5,ymm4\t\n"
      "vfmadd231ps ymm1,ymm6,ymm4\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm5,ymm4\t\n"
      "vfmadd231ps ymm3,ymm6,ymm4\t\n"
      // Restore r12 to the top C row (saved in rcx).
      "mov r12, rcx\t\n"
      // k == 1: skip the inner loop and write C out.
      "test r14,r14\t\n"
      "jnz loop_inner_start%=\t\n"
      "add r10,32\t\n"
      "jmp dump_C%=\t\n"

      // beta == 0 path: same first k-step with plain multiplies, C is
      // never read.
      "zero_regs%=:\t\n"

      "test r14,r14\t\n"
      "jz skip_preload_b_zero%=\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "skip_preload_b_zero%=:\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+0]\t\n"
      "vmulps ymm0,ymm5,ymm4\t\n"
      "vmulps ymm1,ymm6,ymm4\t\n"
      "add r12, r13\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+4]\t\n"
      "vmulps ymm2,ymm5,ymm4\t\n"
      "vmulps ymm3,ymm6,ymm4\t\n"
      // Restore r12 to the top C row.
      "mov r12, rcx\t\n"
      "test r14,r14\t\n"
      "jnz loop_inner_start%=\t\n"
      "add r10,32\t\n"
      "jmp dump_C%=\t\n"

      // Advance past the peeled k-step; take the unrolled loop only
      // when more than 4 iterations remain.
      "loop_inner_start%=:\t\n"
      "add r9,8\t\n"
      "add r10,32\t\n"
      "cmp r14,4\t\n"

      "jle loop_inner_end%=\t\n"

      // Inner loop unrolled 4x: B panels rotate through ymm5/ymm6/ymm15;
      // odd steps accumulate into ymm0..ymm3, even steps into
      // ymm7..ymm10 to break FMA dependency chains.
      "loop_inner%=:\t\n"

      "vcvtph2ps ymm6,XMMWORD PTR [r10 + 16]\t\n"
      "vcvtph2ps ymm5,XMMWORD PTR [r10 + 32]\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm15,ymm4\t\n"
      "vfmadd231ps ymm1,ymm6,ymm4\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm15,ymm4\t\n"
      "vfmadd231ps ymm3,ymm6,ymm4\t\n"

      "vcvtph2ps ymm6,XMMWORD PTR [r10 + 48]\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 64]\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+8]\t\n"
      "vfmadd231ps ymm7,ymm5,ymm4\t\n"
      "vfmadd231ps ymm8,ymm6,ymm4\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+12]\t\n"
      "vfmadd231ps ymm9,ymm5,ymm4\t\n"
      "vfmadd231ps ymm10,ymm6,ymm4\t\n"

      "vcvtph2ps ymm6,XMMWORD PTR [r10 + 80]\t\n"
      "vcvtph2ps ymm5,XMMWORD PTR [r10 + 96]\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+16]\t\n"
      "vfmadd231ps ymm0,ymm15,ymm4\t\n"
      "vfmadd231ps ymm1,ymm6,ymm4\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+20]\t\n"
      "vfmadd231ps ymm2,ymm15,ymm4\t\n"
      "vfmadd231ps ymm3,ymm6,ymm4\t\n"

      "vcvtph2ps ymm6,XMMWORD PTR [r10 + 112]\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 128]\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+24]\t\n"
      "vfmadd231ps ymm7,ymm5,ymm4\t\n"
      "vfmadd231ps ymm8,ymm6,ymm4\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+28]\t\n"
      "vfmadd231ps ymm9,ymm5,ymm4\t\n"
      "vfmadd231ps ymm10,ymm6,ymm4\t\n"

      // 4 k-steps consumed: A advances 4 x 2 floats, B 4 panels.
      "next_inner%=:\t\n"
      "add r9,32\t\n"
      "add r10,128\t\n"
      "sub r14,4\t\n"

      "cmp r14, 4\t\n"
      "jg loop_inner%=\t\n"
      "loop_inner_end%=:\t\n"

      // Drain the remaining 1..4 k-steps one at a time.
      "cmp r14, 0\t\n"
      "jz loop_tail%=\t\n"

      "vcvtph2ps ymm5,XMMWORD PTR [r10]\t\n"
      "vcvtph2ps ymm6,XMMWORD PTR [r10 + 16]\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm5,ymm4\t\n"
      "vfmadd231ps ymm1,ymm6,ymm4\t\n"
      "vbroadcastss ymm4,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm5,ymm4\t\n"
      "vfmadd231ps ymm3,ymm6,ymm4\t\n"
      "add r9,8\t\n"
      "add r10,32\t\n"
      "dec r14\t\n"

      "jmp loop_inner_end%=\t\n"

      // Fold the secondary accumulators into ymm0..ymm3.
      "loop_tail%=:\t\n"
      "vaddps ymm0, ymm0, ymm7\t\n"
      "vaddps ymm1, ymm1, ymm8\t\n"
      "vaddps ymm2, ymm2, ymm9\t\n"
      "vaddps ymm3, ymm3, ymm10\t\n"

      // Dump C
      // Store both rows, stepping by ldc between them.
      "dump_C%=:\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
      "add r12, r13\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm3\t\n"

      // next outer iteration
      // C advances 64 bytes (16 fp32 columns); A rewinds to its start.
      "add rcx, 64\t\n"
      "mov r12, rcx\t\n"
      "mov r9, rax\t\n"
      "inc rbx\t\n"
      "cmp rbx, rdi\t\n"
      "jl loop_outter%=\t\n"
      :
      : [gp] "rm" (gp)
      : "r8" ,
        "r9" ,
        "r10" ,
        "r11" ,
        "r13" ,
        "r14" ,
        "rax" ,
        "rcx" ,
        "rsi" ,
        "rdi" ,
        "rbx" ,
        "r12" ,
        "r15" ,
        "memory" );
}
// 3x2 fp16 GEMM micro-kernel (AVX2 + F16C + FMA).
// Computes three rows, two ymm-wide (16 fp32) panels of C per block
// column of B (accumulators ymm0/1 = row 0, ymm2/3 = row 1,
// ymm4/5 = row 2):
//   C[r, 0:16] = A[r, :] * B_panel + beta * C[r, 0:16], r = 0..2
// Unlike the 1x2/2x2 kernels, the inner loop here is NOT unrolled: it
// processes one k-step per iteration, using ymm15 as a one-panel B
// preload.  The final k-step is peeled after the loop so the preload
// never reads past the end of B (the peeled copy omits the [r10 + 32]
// load).
// Register roles: r8 = k-1, r9 = A, r10 = B, r15 = &beta, r12 = C,
// r13 = ldc (bytes), rdi = b_block_cols, rsi = b_block_size (unused
// below), rax/rcx = saved A/C, rbx = block-column index, r14 =
// remaining k iterations.
void NOINLINE gemmkernel_3x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  asm volatile(
#if FBGEMM_USE_CLANG_INTEL_SYNTAX_ASM_HACK
      // Clang workaround: move gp in AT&T mode, then switch to Intel
      // syntax for the rest of the block.
      "mov %[gp], %%r14\t\n"
      ".intel_syntax noprefix\t\n"
#else
      "mov r14, %[gp]\t\n"
#endif

      // Copy parameters
      // k
      "mov r8, [r14 + 0]\t\n"
      // r8 = k - 1: first k-step is peeled into the setup, last into
      // the post-loop epilogue.
      "dec r8\t\n"
      // A
      "mov r9, [r14 + 8]\t\n"
      // B
      "mov r10, [r14 + 16]\t\n"
      // beta
      "lea r15, [r14 + 24]\t\n"
      // C
      "mov r12, [r14 + 32]\t\n"
      // ldc
      "mov r13, [r14 + 40]\t\n"
      // b_block_cols
      "mov rdi, [r14 + 48]\t\n"
      // b_block_size
      "mov rsi, [r14 + 56]\t\n"

      // Make copies of A and C
      "mov rax, r9\t\n"
      "mov rcx, r12\t\n"

      "xor ebx, ebx\t\n"
      // Outer loop: one iteration per 16-column block of B.
      "loop_outter%=:\t\n"
      "mov r14, r8\t\n"
      // ymm15 = broadcast(beta); ymm7/ymm8 = first B panel (fp16->fp32).
      "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
      "vcvtph2ps ymm7,XMMWORD PTR [r10 + 0]\t\n"
      "vcvtph2ps ymm8,XMMWORD PTR [r10 + 16]\t\n"
      // Branch to the no-read path when beta == 0.0f (ZF set on equal).
      "vxorps xmm0, xmm0, xmm0\t\n"
      "vcomiss xmm15, xmm0\t\n"
      "jz zero_regs%=\t\n"

      // Setup values with beta multiplication
      // Peeled first k-step: acc = beta * C + A[r] * B_panel for all
      // three rows; r12 steps down C by ldc between rows.
      "vmulps ymm0, ymm15, [r12 + 0]\t\n"
      "vmulps ymm1, ymm15, [r12 + 32]\t\n"
      "add r12, r13\t\n"
      "vmulps ymm2, ymm15, [r12 + 0]\t\n"
      "vmulps ymm3, ymm15, [r12 + 32]\t\n"
      "add r12, r13\t\n"
      "vmulps ymm4, ymm15, [r12 + 0]\t\n"
      "vmulps ymm5, ymm15, [r12 + 32]\t\n"
      "test r14,r14\t\n"
      "jz skip_preload%=\t\n"
      // Preload next B panel; beta in ymm15 is no longer needed.
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "skip_preload%=:\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm7,ymm6\t\n"
      "vfmadd231ps ymm1,ymm8,ymm6\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm7,ymm6\t\n"
      "vfmadd231ps ymm3,ymm8,ymm6\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+8]\t\n"
      "vfmadd231ps ymm4,ymm7,ymm6\t\n"
      "vfmadd231ps ymm5,ymm8,ymm6\t\n"
      // Restore r12 to the top C row (saved in rcx).
      "mov r12, rcx\t\n"
      // k == 1: nothing left for the inner loop; write C out.
      "test r14,r14\t\n"
      "jnz next_inner%=\t\n"
      "add r10,32\t\n"
      "jmp dump_C%=\t\n"

      // beta == 0 path: same first k-step with plain multiplies, C is
      // never read.
      "zero_regs%=:\t\n"

      "test r14,r14\t\n"
      "jz skip_preload_b_zero%=\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "skip_preload_b_zero%=:\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+0]\t\n"
      "vmulps ymm0,ymm7,ymm6\t\n"
      "vmulps ymm1,ymm8,ymm6\t\n"
      "add r12, r13\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+4]\t\n"
      "vmulps ymm2,ymm7,ymm6\t\n"
      "vmulps ymm3,ymm8,ymm6\t\n"
      "add r12, r13\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+8]\t\n"
      "vmulps ymm4,ymm7,ymm6\t\n"
      "vmulps ymm5,ymm8,ymm6\t\n"
      // Restore r12 to the top C row.
      "mov r12, rcx\t\n"
      "test r14,r14\t\n"
      "jnz next_inner%=\t\n"
      "add r10,32\t\n"
      "jmp dump_C%=\t\n"

      // Inner loop: one k-step per iteration.  The preloaded panel in
      // ymm15 becomes the current panel (ymm7); the next panel is
      // preloaded from [r10 + 32].
      "loop_inner%=:\t\n"

      "vmovaps ymm7,ymm15\t\n"
      "vcvtph2ps ymm8,XMMWORD PTR [r10 + 16]\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm7,ymm6\t\n"
      "vfmadd231ps ymm1,ymm8,ymm6\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm7,ymm6\t\n"
      "vfmadd231ps ymm3,ymm8,ymm6\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+8]\t\n"
      "vfmadd231ps ymm4,ymm7,ymm6\t\n"
      "vfmadd231ps ymm5,ymm8,ymm6\t\n"

      // One k-step consumed: A advances 3 floats, B one panel.
      "next_inner%=:\t\n"
      "add r9,12\t\n"
      "add r10,32\t\n"
      "dec r14\t\n"
      "jnz loop_inner%=\t\n"

      // Peeled final k-step: identical to the loop body but without the
      // [r10 + 32] preload, so B is never read past its end.
      "vmovaps ymm7,ymm15\t\n"
      "vcvtph2ps ymm8,XMMWORD PTR [r10 + 16]\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm7,ymm6\t\n"
      "vfmadd231ps ymm1,ymm8,ymm6\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm7,ymm6\t\n"
      "vfmadd231ps ymm3,ymm8,ymm6\t\n"
      "vbroadcastss ymm6,DWORD PTR [r9+8]\t\n"
      "vfmadd231ps ymm4,ymm7,ymm6\t\n"
      "vfmadd231ps ymm5,ymm8,ymm6\t\n"
      "add r9,12\t\n"
      "add r10,32\t\n"
      // Dump C
      // Store all three rows, stepping by ldc between them.
      "dump_C%=:\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
      "add r12, r13\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
      "add r12, r13\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm4\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm5\t\n"

      // next outer iteration
      // C advances 64 bytes (16 fp32 columns); A rewinds to its start.
      "add rcx, 64\t\n"
      "mov r12, rcx\t\n"
      "mov r9, rax\t\n"
      "inc rbx\t\n"
      "cmp rbx, rdi\t\n"
      "jl loop_outter%=\t\n"
      :
      : [gp] "rm" (gp)
      : "r8" ,
        "r9" ,
        "r10" ,
        "r11" ,
        "r13" ,
        "r14" ,
        "rax" ,
        "rcx" ,
        "rsi" ,
        "rdi" ,
        "rbx" ,
        "r12" ,
        "r15" ,
        "memory" );
}
// 4x2 fp16 GEMM micro-kernel (AVX2 + F16C + FMA).
// Computes four rows, two ymm-wide (16 fp32) panels of C per block
// column of B (accumulators ymm0/1 = row 0 ... ymm6/7 = row 3):
//   C[r, 0:16] = A[r, :] * B_panel + beta * C[r, 0:16], r = 0..3
// Same structure as the 3x2 kernel: the inner loop handles one k-step
// per iteration with a one-panel B preload in ymm15, and the final
// k-step is peeled after the loop (its copy of the body omits the
// [r10 + 32] preload so B is never read past its end).
// Register roles: r8 = k-1, r9 = A, r10 = B, r15 = &beta, r12 = C,
// r13 = ldc (bytes), rdi = b_block_cols, rsi = b_block_size (unused
// below), rax/rcx = saved A/C, rbx = block-column index, r14 =
// remaining k iterations.
void NOINLINE gemmkernel_4x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  asm volatile(
#if FBGEMM_USE_CLANG_INTEL_SYNTAX_ASM_HACK
      // Clang workaround: move gp in AT&T mode, then switch to Intel
      // syntax for the rest of the block.
      "mov %[gp], %%r14\t\n"
      ".intel_syntax noprefix\t\n"
#else
      "mov r14, %[gp]\t\n"
#endif

      // Copy parameters
      // k
      "mov r8, [r14 + 0]\t\n"
      // r8 = k - 1: first k-step peeled into setup, last into epilogue.
      "dec r8\t\n"
      // A
      "mov r9, [r14 + 8]\t\n"
      // B
      "mov r10, [r14 + 16]\t\n"
      // beta
      "lea r15, [r14 + 24]\t\n"
      // C
      "mov r12, [r14 + 32]\t\n"
      // ldc
      "mov r13, [r14 + 40]\t\n"
      // b_block_cols
      "mov rdi, [r14 + 48]\t\n"
      // b_block_size
      "mov rsi, [r14 + 56]\t\n"

      // Make copies of A and C
      "mov rax, r9\t\n"
      "mov rcx, r12\t\n"

      "xor ebx, ebx\t\n"
      // Outer loop: one iteration per 16-column block of B.
      "loop_outter%=:\t\n"
      "mov r14, r8\t\n"
      // ymm15 = broadcast(beta); ymm9/ymm10 = first B panel.
      "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
      "vcvtph2ps ymm9,XMMWORD PTR [r10 + 0]\t\n"
      "vcvtph2ps ymm10,XMMWORD PTR [r10 + 16]\t\n"
      // Branch to the no-read path when beta == 0.0f (ZF set on equal).
      "vxorps xmm0, xmm0, xmm0\t\n"
      "vcomiss xmm15, xmm0\t\n"
      "jz zero_regs%=\t\n"

      // Setup values with beta multiplication
      // Peeled first k-step: acc = beta * C + A[r] * B_panel for all
      // four rows; r12 steps down C by ldc between rows.
      "vmulps ymm0, ymm15, [r12 + 0]\t\n"
      "vmulps ymm1, ymm15, [r12 + 32]\t\n"
      "add r12, r13\t\n"
      "vmulps ymm2, ymm15, [r12 + 0]\t\n"
      "vmulps ymm3, ymm15, [r12 + 32]\t\n"
      "add r12, r13\t\n"
      "vmulps ymm4, ymm15, [r12 + 0]\t\n"
      "vmulps ymm5, ymm15, [r12 + 32]\t\n"
      "add r12, r13\t\n"
      "vmulps ymm6, ymm15, [r12 + 0]\t\n"
      "vmulps ymm7, ymm15, [r12 + 32]\t\n"
      "test r14,r14\t\n"
      "jz skip_preload%=\t\n"
      // Preload next B panel; beta in ymm15 is no longer needed.
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "skip_preload%=:\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm9,ymm8\t\n"
      "vfmadd231ps ymm1,ymm10,ymm8\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm9,ymm8\t\n"
      "vfmadd231ps ymm3,ymm10,ymm8\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+8]\t\n"
      "vfmadd231ps ymm4,ymm9,ymm8\t\n"
      "vfmadd231ps ymm5,ymm10,ymm8\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+12]\t\n"
      "vfmadd231ps ymm6,ymm9,ymm8\t\n"
      "vfmadd231ps ymm7,ymm10,ymm8\t\n"
      // Restore r12 to the top C row (saved in rcx).
      "mov r12, rcx\t\n"
      // k == 1: nothing left for the inner loop; write C out.
      "test r14,r14\t\n"
      "jnz next_inner%=\t\n"
      "add r10,32\t\n"
      "jmp dump_C%=\t\n"

      // beta == 0 path: same first k-step with plain multiplies, C is
      // never read.
      "zero_regs%=:\t\n"

      "test r14,r14\t\n"
      "jz skip_preload_b_zero%=\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "skip_preload_b_zero%=:\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+0]\t\n"
      "vmulps ymm0,ymm9,ymm8\t\n"
      "vmulps ymm1,ymm10,ymm8\t\n"
      "add r12, r13\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+4]\t\n"
      "vmulps ymm2,ymm9,ymm8\t\n"
      "vmulps ymm3,ymm10,ymm8\t\n"
      "add r12, r13\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+8]\t\n"
      "vmulps ymm4,ymm9,ymm8\t\n"
      "vmulps ymm5,ymm10,ymm8\t\n"
      "add r12, r13\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+12]\t\n"
      "vmulps ymm6,ymm9,ymm8\t\n"
      "vmulps ymm7,ymm10,ymm8\t\n"
      // Restore r12 to the top C row.
      "mov r12, rcx\t\n"
      "test r14,r14\t\n"
      "jnz next_inner%=\t\n"
      "add r10,32\t\n"
      "jmp dump_C%=\t\n"

      // Inner loop: one k-step per iteration; preloaded panel in ymm15
      // becomes current (ymm9), next panel preloaded from [r10 + 32].
      "loop_inner%=:\t\n"

      "vmovaps ymm9,ymm15\t\n"
      "vcvtph2ps ymm10,XMMWORD PTR [r10 + 16]\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm9,ymm8\t\n"
      "vfmadd231ps ymm1,ymm10,ymm8\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm9,ymm8\t\n"
      "vfmadd231ps ymm3,ymm10,ymm8\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+8]\t\n"
      "vfmadd231ps ymm4,ymm9,ymm8\t\n"
      "vfmadd231ps ymm5,ymm10,ymm8\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+12]\t\n"
      "vfmadd231ps ymm6,ymm9,ymm8\t\n"
      "vfmadd231ps ymm7,ymm10,ymm8\t\n"

      // One k-step consumed: A advances 4 floats, B one panel.
      "next_inner%=:\t\n"
      "add r9,16\t\n"
      "add r10,32\t\n"
      "dec r14\t\n"
      "jnz loop_inner%=\t\n"

      // Peeled final k-step: same as the loop body but without the
      // [r10 + 32] preload, so B is never read past its end.
      "vmovaps ymm9,ymm15\t\n"
      "vcvtph2ps ymm10,XMMWORD PTR [r10 + 16]\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm9,ymm8\t\n"
      "vfmadd231ps ymm1,ymm10,ymm8\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm9,ymm8\t\n"
      "vfmadd231ps ymm3,ymm10,ymm8\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+8]\t\n"
      "vfmadd231ps ymm4,ymm9,ymm8\t\n"
      "vfmadd231ps ymm5,ymm10,ymm8\t\n"
      "vbroadcastss ymm8,DWORD PTR [r9+12]\t\n"
      "vfmadd231ps ymm6,ymm9,ymm8\t\n"
      "vfmadd231ps ymm7,ymm10,ymm8\t\n"
      "add r9,16\t\n"
      "add r10,32\t\n"
      // Dump C
      // Store all four rows, stepping by ldc between them.
      "dump_C%=:\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
      "add r12, r13\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
      "add r12, r13\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm4\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm5\t\n"
      "add r12, r13\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm6\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm7\t\n"

      // next outer iteration
      // C advances 64 bytes (16 fp32 columns); A rewinds to its start.
      "add rcx, 64\t\n"
      "mov r12, rcx\t\n"
      "mov r9, rax\t\n"
      "inc rbx\t\n"
      "cmp rbx, rdi\t\n"
      "jl loop_outter%=\t\n"
      :
      : [gp] "rm" (gp)
      : "r8" ,
        "r9" ,
        "r10" ,
        "r11" ,
        "r13" ,
        "r14" ,
        "rax" ,
        "rcx" ,
        "rsi" ,
        "rdi" ,
        "rbx" ,
        "r12" ,
        "r15" ,
        "memory" );
}
// 5x2 fp16 GEMM micro-kernel (AVX2 + F16C + FMA).
// Computes five rows, two ymm-wide (16 fp32) panels of C per block
// column of B (accumulators ymm0/1 = row 0 ... ymm8/9 = row 4):
//   C[r, 0:16] = A[r, :] * B_panel + beta * C[r, 0:16], r = 0..4
// Same structure as the 3x2/4x2 kernels: one k-step per inner-loop
// iteration with a one-panel B preload in ymm15, and the final k-step
// peeled after the loop (its copy of the body omits the [r10 + 32]
// preload so B is never read past its end).
// Register roles: r8 = k-1, r9 = A, r10 = B, r15 = &beta, r12 = C,
// r13 = ldc (bytes), rdi = b_block_cols, rsi = b_block_size (unused
// below), rax/rcx = saved A/C, rbx = block-column index, r14 =
// remaining k iterations.
void NOINLINE gemmkernel_5x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  asm volatile(
#if FBGEMM_USE_CLANG_INTEL_SYNTAX_ASM_HACK
      // Clang workaround: move gp in AT&T mode, then switch to Intel
      // syntax for the rest of the block.
      "mov %[gp], %%r14\t\n"
      ".intel_syntax noprefix\t\n"
#else
      "mov r14, %[gp]\t\n"
#endif

      // Copy parameters
      // k
      "mov r8, [r14 + 0]\t\n"
      // r8 = k - 1: first k-step peeled into setup, last into epilogue.
      "dec r8\t\n"
      // A
      "mov r9, [r14 + 8]\t\n"
      // B
      "mov r10, [r14 + 16]\t\n"
      // beta
      "lea r15, [r14 + 24]\t\n"
      // C
      "mov r12, [r14 + 32]\t\n"
      // ldc
      "mov r13, [r14 + 40]\t\n"
      // b_block_cols
      "mov rdi, [r14 + 48]\t\n"
      // b_block_size
      "mov rsi, [r14 + 56]\t\n"

      // Make copies of A and C
      "mov rax, r9\t\n"
      "mov rcx, r12\t\n"

      "xor ebx, ebx\t\n"
      // Outer loop: one iteration per 16-column block of B.
      "loop_outter%=:\t\n"
      "mov r14, r8\t\n"
      // ymm15 = broadcast(beta); ymm11/ymm12 = first B panel.
      "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
      "vcvtph2ps ymm11,XMMWORD PTR [r10 + 0]\t\n"
      "vcvtph2ps ymm12,XMMWORD PTR [r10 + 16]\t\n"
      // Branch to the no-read path when beta == 0.0f (ZF set on equal).
      "vxorps xmm0, xmm0, xmm0\t\n"
      "vcomiss xmm15, xmm0\t\n"
      "jz zero_regs%=\t\n"

      // Setup values with beta multiplication
      // Peeled first k-step: acc = beta * C + A[r] * B_panel for all
      // five rows; r12 steps down C by ldc between rows.
      "vmulps ymm0, ymm15, [r12 + 0]\t\n"
      "vmulps ymm1, ymm15, [r12 + 32]\t\n"
      "add r12, r13\t\n"
      "vmulps ymm2, ymm15, [r12 + 0]\t\n"
      "vmulps ymm3, ymm15, [r12 + 32]\t\n"
      "add r12, r13\t\n"
      "vmulps ymm4, ymm15, [r12 + 0]\t\n"
      "vmulps ymm5, ymm15, [r12 + 32]\t\n"
      "add r12, r13\t\n"
      "vmulps ymm6, ymm15, [r12 + 0]\t\n"
      "vmulps ymm7, ymm15, [r12 + 32]\t\n"
      "add r12, r13\t\n"
      "vmulps ymm8, ymm15, [r12 + 0]\t\n"
      "vmulps ymm9, ymm15, [r12 + 32]\t\n"
      "test r14,r14\t\n"
      "jz skip_preload%=\t\n"
      // Preload next B panel; beta in ymm15 is no longer needed.
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "skip_preload%=:\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm11,ymm10\t\n"
      "vfmadd231ps ymm1,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm11,ymm10\t\n"
      "vfmadd231ps ymm3,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+8]\t\n"
      "vfmadd231ps ymm4,ymm11,ymm10\t\n"
      "vfmadd231ps ymm5,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+12]\t\n"
      "vfmadd231ps ymm6,ymm11,ymm10\t\n"
      "vfmadd231ps ymm7,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+16]\t\n"
      "vfmadd231ps ymm8,ymm11,ymm10\t\n"
      "vfmadd231ps ymm9,ymm12,ymm10\t\n"
      // Restore r12 to the top C row (saved in rcx).
      "mov r12, rcx\t\n"
      // k == 1: nothing left for the inner loop; write C out.
      "test r14,r14\t\n"
      "jnz next_inner%=\t\n"
      "add r10,32\t\n"
      "jmp dump_C%=\t\n"

      // beta == 0 path: same first k-step with plain multiplies, C is
      // never read.
      "zero_regs%=:\t\n"

      "test r14,r14\t\n"
      "jz skip_preload_b_zero%=\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "skip_preload_b_zero%=:\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+0]\t\n"
      "vmulps ymm0,ymm11,ymm10\t\n"
      "vmulps ymm1,ymm12,ymm10\t\n"
      "add r12, r13\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+4]\t\n"
      "vmulps ymm2,ymm11,ymm10\t\n"
      "vmulps ymm3,ymm12,ymm10\t\n"
      "add r12, r13\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+8]\t\n"
      "vmulps ymm4,ymm11,ymm10\t\n"
      "vmulps ymm5,ymm12,ymm10\t\n"
      "add r12, r13\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+12]\t\n"
      "vmulps ymm6,ymm11,ymm10\t\n"
      "vmulps ymm7,ymm12,ymm10\t\n"
      "add r12, r13\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+16]\t\n"
      "vmulps ymm8,ymm11,ymm10\t\n"
      "vmulps ymm9,ymm12,ymm10\t\n"
      // Restore r12 to the top C row.
      "mov r12, rcx\t\n"
      "test r14,r14\t\n"
      "jnz next_inner%=\t\n"
      "add r10,32\t\n"
      "jmp dump_C%=\t\n"

      // Inner loop: one k-step per iteration; preloaded panel in ymm15
      // becomes current (ymm11), next panel preloaded from [r10 + 32].
      "loop_inner%=:\t\n"

      "vmovaps ymm11,ymm15\t\n"
      "vcvtph2ps ymm12,XMMWORD PTR [r10 + 16]\t\n"
      "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm11,ymm10\t\n"
      "vfmadd231ps ymm1,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm11,ymm10\t\n"
      "vfmadd231ps ymm3,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+8]\t\n"
      "vfmadd231ps ymm4,ymm11,ymm10\t\n"
      "vfmadd231ps ymm5,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+12]\t\n"
      "vfmadd231ps ymm6,ymm11,ymm10\t\n"
      "vfmadd231ps ymm7,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+16]\t\n"
      "vfmadd231ps ymm8,ymm11,ymm10\t\n"
      "vfmadd231ps ymm9,ymm12,ymm10\t\n"

      // One k-step consumed: A advances 5 floats, B one panel.
      "next_inner%=:\t\n"
      "add r9,20\t\n"
      "add r10,32\t\n"
      "dec r14\t\n"
      "jnz loop_inner%=\t\n"

      // Peeled final k-step: same as the loop body but without the
      // [r10 + 32] preload, so B is never read past its end.
      "vmovaps ymm11,ymm15\t\n"
      "vcvtph2ps ymm12,XMMWORD PTR [r10 + 16]\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+0]\t\n"
      "vfmadd231ps ymm0,ymm11,ymm10\t\n"
      "vfmadd231ps ymm1,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+4]\t\n"
      "vfmadd231ps ymm2,ymm11,ymm10\t\n"
      "vfmadd231ps ymm3,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+8]\t\n"
      "vfmadd231ps ymm4,ymm11,ymm10\t\n"
      "vfmadd231ps ymm5,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+12]\t\n"
      "vfmadd231ps ymm6,ymm11,ymm10\t\n"
      "vfmadd231ps ymm7,ymm12,ymm10\t\n"
      "vbroadcastss ymm10,DWORD PTR [r9+16]\t\n"
      "vfmadd231ps ymm8,ymm11,ymm10\t\n"
      "vfmadd231ps ymm9,ymm12,ymm10\t\n"
      "add r9,20\t\n"
      "add r10,32\t\n"
      // Dump C
      // Store all five rows, stepping by ldc between them.
      "dump_C%=:\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
      "add r12, r13\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
      "add r12, r13\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm4\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm5\t\n"
      "add r12, r13\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm6\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm7\t\n"
      "add r12, r13\t\n"
      "vmovups ymmword PTR [r12 + 0], ymm8\t\n"
      "vmovups ymmword PTR [r12 + 32], ymm9\t\n"

      // next outer iteration
      // C advances 64 bytes (16 fp32 columns); A rewinds to its start.
      "add rcx, 64\t\n"
      "mov r12, rcx\t\n"
      "mov r9, rax\t\n"
      "inc rbx\t\n"
      "cmp rbx, rdi\t\n"
      "jl loop_outter%=\t\n"
      :
      : [gp] "rm" (gp)
      : "r8" ,
        "r9" ,
        "r10" ,
        "r11" ,
        "r13" ,
        "r14" ,
        "rax" ,
        "rcx" ,
        "rsi" ,
        "rdi" ,
        "rbx" ,
        "r12" ,
        "r15" ,
        "memory" );
}
925 | void NOINLINE gemmkernel_6x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { |
926 | asm volatile( |
927 | #if FBGEMM_USE_CLANG_INTEL_SYNTAX_ASM_HACK |
928 | "mov %[gp], %%r14\t\n" |
929 | ".intel_syntax noprefix\t\n" |
930 | #else |
931 | "mov r14, %[gp]\t\n" |
932 | #endif |
933 | |
934 | // Copy parameters |
935 | // k |
936 | "mov r8, [r14 + 0]\t\n" |
937 | "dec r8\t\n" |
938 | // A |
939 | "mov r9, [r14 + 8]\t\n" |
940 | // B |
941 | "mov r10, [r14 + 16]\t\n" |
942 | // beta |
943 | "lea r15, [r14 + 24]\t\n" |
944 | // C |
945 | "mov r12, [r14 + 32]\t\n" |
946 | // ldc |
947 | "mov r13, [r14 + 40]\t\n" |
948 | // b_block_cols |
949 | "mov rdi, [r14 + 48]\t\n" |
950 | // b_block_size |
951 | "mov rsi, [r14 + 56]\t\n" |
952 | |
953 | // Make copies of A and C |
954 | "mov rax, r9\t\n" |
955 | "mov rcx, r12\t\n" |
956 | |
957 | "xor ebx, ebx\t\n" |
958 | "loop_outter%=:\t\n" |
959 | "mov r14, r8\t\n" |
960 | "vbroadcastss ymm15,DWORD PTR [r15]\t\n" |
961 | "vcvtph2ps ymm13,XMMWORD PTR [r10 + 0]\t\n" |
962 | "vcvtph2ps ymm14,XMMWORD PTR [r10 + 16]\t\n" |
963 | "vxorps xmm0, xmm0, xmm0\t\n" |
964 | "vcomiss xmm15, xmm0\t\n" |
965 | "jz zero_regs%=\t\n" |
966 | |
967 | // Setup values with beta multiplication |
968 | "vmulps ymm0, ymm15, [r12 + 0]\t\n" |
969 | "vmulps ymm1, ymm15, [r12 + 32]\t\n" |
970 | "add r12, r13\t\n" |
971 | "vmulps ymm2, ymm15, [r12 + 0]\t\n" |
972 | "vmulps ymm3, ymm15, [r12 + 32]\t\n" |
973 | "add r12, r13\t\n" |
974 | "vmulps ymm4, ymm15, [r12 + 0]\t\n" |
975 | "vmulps ymm5, ymm15, [r12 + 32]\t\n" |
976 | "add r12, r13\t\n" |
977 | "vmulps ymm6, ymm15, [r12 + 0]\t\n" |
978 | "vmulps ymm7, ymm15, [r12 + 32]\t\n" |
979 | "add r12, r13\t\n" |
980 | "vmulps ymm8, ymm15, [r12 + 0]\t\n" |
981 | "vmulps ymm9, ymm15, [r12 + 32]\t\n" |
982 | "add r12, r13\t\n" |
983 | "vmulps ymm10, ymm15, [r12 + 0]\t\n" |
984 | "vmulps ymm11, ymm15, [r12 + 32]\t\n" |
985 | "test r14,r14\t\n" |
986 | "jz skip_preload%=\t\n" |
987 | "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n" |
988 | "skip_preload%=:\t\n" |
989 | "vbroadcastss ymm12,DWORD PTR [r9+0]\t\n" |
990 | "vfmadd231ps ymm0,ymm13,ymm12\t\n" |
991 | "vfmadd231ps ymm1,ymm14,ymm12\t\n" |
992 | "vbroadcastss ymm12,DWORD PTR [r9+4]\t\n" |
993 | "vfmadd231ps ymm2,ymm13,ymm12\t\n" |
994 | "vfmadd231ps ymm3,ymm14,ymm12\t\n" |
995 | "vbroadcastss ymm12,DWORD PTR [r9+8]\t\n" |
996 | "vfmadd231ps ymm4,ymm13,ymm12\t\n" |
997 | "vfmadd231ps ymm5,ymm14,ymm12\t\n" |
998 | "vbroadcastss ymm12,DWORD PTR [r9+12]\t\n" |
999 | "vfmadd231ps ymm6,ymm13,ymm12\t\n" |
1000 | "vfmadd231ps ymm7,ymm14,ymm12\t\n" |
1001 | "vbroadcastss ymm12,DWORD PTR [r9+16]\t\n" |
1002 | "vfmadd231ps ymm8,ymm13,ymm12\t\n" |
1003 | "vfmadd231ps ymm9,ymm14,ymm12\t\n" |
1004 | "vbroadcastss ymm12,DWORD PTR [r9+20]\t\n" |
1005 | "vfmadd231ps ymm10,ymm13,ymm12\t\n" |
1006 | "vfmadd231ps ymm11,ymm14,ymm12\t\n" |
1007 | "mov r12, rcx\t\n" |
1008 | "test r14,r14\t\n" |
1009 | "jnz next_inner%=\t\n" |
1010 | "add r10,32\t\n" |
1011 | "jmp dump_C%=\t\n" |
1012 | |
1013 | "zero_regs%=:\t\n" |
1014 | |
1015 | "test r14,r14\t\n" |
1016 | "jz skip_preload_b_zero%=\t\n" |
1017 | "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n" |
1018 | "skip_preload_b_zero%=:\t\n" |
1019 | "vbroadcastss ymm12,DWORD PTR [r9+0]\t\n" |
1020 | "vmulps ymm0,ymm13,ymm12\t\n" |
1021 | "vmulps ymm1,ymm14,ymm12\t\n" |
1022 | "add r12, r13\t\n" |
1023 | "vbroadcastss ymm12,DWORD PTR [r9+4]\t\n" |
1024 | "vmulps ymm2,ymm13,ymm12\t\n" |
1025 | "vmulps ymm3,ymm14,ymm12\t\n" |
1026 | "add r12, r13\t\n" |
1027 | "vbroadcastss ymm12,DWORD PTR [r9+8]\t\n" |
1028 | "vmulps ymm4,ymm13,ymm12\t\n" |
1029 | "vmulps ymm5,ymm14,ymm12\t\n" |
1030 | "add r12, r13\t\n" |
1031 | "vbroadcastss ymm12,DWORD PTR [r9+12]\t\n" |
1032 | "vmulps ymm6,ymm13,ymm12\t\n" |
1033 | "vmulps ymm7,ymm14,ymm12\t\n" |
1034 | "add r12, r13\t\n" |
1035 | "vbroadcastss ymm12,DWORD PTR [r9+16]\t\n" |
1036 | "vmulps ymm8,ymm13,ymm12\t\n" |
1037 | "vmulps ymm9,ymm14,ymm12\t\n" |
1038 | "add r12, r13\t\n" |
1039 | "vbroadcastss ymm12,DWORD PTR [r9+20]\t\n" |
1040 | "vmulps ymm10,ymm13,ymm12\t\n" |
1041 | "vmulps ymm11,ymm14,ymm12\t\n" |
1042 | "mov r12, rcx\t\n" |
1043 | "test r14,r14\t\n" |
1044 | "jnz next_inner%=\t\n" |
1045 | "add r10,32\t\n" |
1046 | "jmp dump_C%=\t\n" |
1047 | |
1048 | "loop_inner%=:\t\n" |
1049 | |
1050 | "vmovaps ymm13,ymm15\t\n" |
1051 | "vcvtph2ps ymm14,XMMWORD PTR [r10 + 16]\t\n" |
1052 | "vcvtph2ps ymm15,XMMWORD PTR [r10 + 32]\t\n" |
1053 | "vbroadcastss ymm12,DWORD PTR [r9+0]\t\n" |
1054 | "vfmadd231ps ymm0,ymm13,ymm12\t\n" |
1055 | "vfmadd231ps ymm1,ymm14,ymm12\t\n" |
1056 | "vbroadcastss ymm12,DWORD PTR [r9+4]\t\n" |
1057 | "vfmadd231ps ymm2,ymm13,ymm12\t\n" |
1058 | "vfmadd231ps ymm3,ymm14,ymm12\t\n" |
1059 | "vbroadcastss ymm12,DWORD PTR [r9+8]\t\n" |
1060 | "vfmadd231ps ymm4,ymm13,ymm12\t\n" |
1061 | "vfmadd231ps ymm5,ymm14,ymm12\t\n" |
1062 | "vbroadcastss ymm12,DWORD PTR [r9+12]\t\n" |
1063 | "vfmadd231ps ymm6,ymm13,ymm12\t\n" |
1064 | "vfmadd231ps ymm7,ymm14,ymm12\t\n" |
1065 | "vbroadcastss ymm12,DWORD PTR [r9+16]\t\n" |
1066 | "vfmadd231ps ymm8,ymm13,ymm12\t\n" |
1067 | "vfmadd231ps ymm9,ymm14,ymm12\t\n" |
1068 | "vbroadcastss ymm12,DWORD PTR [r9+20]\t\n" |
1069 | "vfmadd231ps ymm10,ymm13,ymm12\t\n" |
1070 | "vfmadd231ps ymm11,ymm14,ymm12\t\n" |
1071 | |
1072 | "next_inner%=:\t\n" |
1073 | "add r9,24\t\n" |
1074 | "add r10,32\t\n" |
1075 | "dec r14\t\n" |
1076 | "jnz loop_inner%=\t\n" |
1077 | |
1078 | "vmovaps ymm13,ymm15\t\n" |
1079 | "vcvtph2ps ymm14,XMMWORD PTR [r10 + 16]\t\n" |
1080 | "vbroadcastss ymm12,DWORD PTR [r9+0]\t\n" |
1081 | "vfmadd231ps ymm0,ymm13,ymm12\t\n" |
1082 | "vfmadd231ps ymm1,ymm14,ymm12\t\n" |
1083 | "vbroadcastss ymm12,DWORD PTR [r9+4]\t\n" |
1084 | "vfmadd231ps ymm2,ymm13,ymm12\t\n" |
1085 | "vfmadd231ps ymm3,ymm14,ymm12\t\n" |
1086 | "vbroadcastss ymm12,DWORD PTR [r9+8]\t\n" |
1087 | "vfmadd231ps ymm4,ymm13,ymm12\t\n" |
1088 | "vfmadd231ps ymm5,ymm14,ymm12\t\n" |
1089 | "vbroadcastss ymm12,DWORD PTR [r9+12]\t\n" |
1090 | "vfmadd231ps ymm6,ymm13,ymm12\t\n" |
1091 | "vfmadd231ps ymm7,ymm14,ymm12\t\n" |
1092 | "vbroadcastss ymm12,DWORD PTR [r9+16]\t\n" |
1093 | "vfmadd231ps ymm8,ymm13,ymm12\t\n" |
1094 | "vfmadd231ps ymm9,ymm14,ymm12\t\n" |
1095 | "vbroadcastss ymm12,DWORD PTR [r9+20]\t\n" |
1096 | "vfmadd231ps ymm10,ymm13,ymm12\t\n" |
1097 | "vfmadd231ps ymm11,ymm14,ymm12\t\n" |
1098 | "add r9,24\t\n" |
1099 | "add r10,32\t\n" |
1100 | // Dump C |
1101 | "dump_C%=:\t\n" |
1102 | "vmovups ymmword PTR [r12 + 0], ymm0\t\n" |
1103 | "vmovups ymmword PTR [r12 + 32], ymm1\t\n" |
1104 | "add r12, r13\t\n" |
1105 | "vmovups ymmword PTR [r12 + 0], ymm2\t\n" |
1106 | "vmovups ymmword PTR [r12 + 32], ymm3\t\n" |
1107 | "add r12, r13\t\n" |
1108 | "vmovups ymmword PTR [r12 + 0], ymm4\t\n" |
1109 | "vmovups ymmword PTR [r12 + 32], ymm5\t\n" |
1110 | "add r12, r13\t\n" |
1111 | "vmovups ymmword PTR [r12 + 0], ymm6\t\n" |
1112 | "vmovups ymmword PTR [r12 + 32], ymm7\t\n" |
1113 | "add r12, r13\t\n" |
1114 | "vmovups ymmword PTR [r12 + 0], ymm8\t\n" |
1115 | "vmovups ymmword PTR [r12 + 32], ymm9\t\n" |
1116 | "add r12, r13\t\n" |
1117 | "vmovups ymmword PTR [r12 + 0], ymm10\t\n" |
1118 | "vmovups ymmword PTR [r12 + 32], ymm11\t\n" |
1119 | |
1120 | // next outer iteration |
1121 | "add rcx, 64\t\n" |
1122 | "mov r12, rcx\t\n" |
1123 | "mov r9, rax\t\n" |
1124 | "inc rbx\t\n" |
1125 | "cmp rbx, rdi\t\n" |
1126 | "jl loop_outter%=\t\n" |
1127 | : |
1128 | : [gp] "rm" (gp) |
1129 | : "r8" , |
1130 | "r9" , |
1131 | "r10" , |
1132 | "r11" , |
1133 | "r13" , |
1134 | "r14" , |
1135 | "rax" , |
1136 | "rcx" , |
1137 | "rsi" , |
1138 | "rdi" , |
1139 | "rbx" , |
1140 | "r12" , |
1141 | "r15" , |
1142 | "memory" ); |
1143 | } |
1144 | |
1145 | } // namespace fbgemm |
1146 | |