/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include "./GenerateI8Depthwise.h"

#include <asmjit/asmjit.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>
#include <numeric>
#include <string>
#include <tuple>

#include "./CodeCache.h"
#include "./CodeGenHelpers.h"
#include "fbgemm/Utils.h"

namespace fbgemm {

namespace {
asmjit::JitRuntime& runtime() {
  // JIT runtime for asmjit. It depends on other static variables, so it is
  // kept function-local to prevent the static initialization order fiasco.
  static asmjit::JitRuntime rt;
  return rt;
}

// Controls access to runtime().
std::mutex rtMutex_;

// The hash depends on D, K_T, K_H, K_W, oc_per_g, compute_a_sum,
// remainder, prev_skip, next_skip, top_skip, bottom_skip, left_skip, and
// right_skip.
CodeCache<
    std::
        tuple<int, int, int, int, int, bool, int, int, int, int, int, int, int>,
    GenI8Depthwise::jit_kernel_signature>
    codeCache_;
} // namespace

namespace x86 = asmjit::x86;

// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
// A is in uint8_t
// B is in int8_t and pre-interleaved
// C is in int32_t and 4 registers have results in the following layout:
// c0_v: c[0:4], c[16:20]
// c1_v: c[4:8], c[20:24]
// c2_v: c[8:12], c[24:28]
// c3_v: c[12:16], c[28:32]
static void genMaddEpi16xNPacked(
    x86::Emitter* e,
    x86::Ymm a[4],
    x86::Gp b,
    x86::Ymm c[4],
    x86::Ymm* a_sum,
    int n,
    int remainder,
    bool accumulation,
    x86::Ymm one_epi8,
    x86::Ymm one_epi16,
    x86::Ymm zero) {
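  // Strategy: interleave the inputs of up to 4 filter taps bytewise so each
  // 32-bit group holds one byte from every tap, multiply against the
  // pre-interleaved B with vpmaddubsw (u8 x s8 with adjacent pairs summed
  // into int16), then widen to int32. The in-lane (128-bit) unpacks are what
  // produce the permuted output layout documented above.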
  // Interleave inputs corresponding to 4 filter positions.
  // Reuse a[1] and a[3] to save registers
  x86::Ymm a01_lo(0), a01_hi(1), a23_lo(a[1]), a23_hi(a[3]);
  e->vpunpcklbw(a01_lo, a[0], n == 1 ? zero : a[1]);
  if (remainder >= 8) {
    e->vpunpckhbw(a01_hi, a[0], n == 1 ? zero : a[1]);
  }
  if (n > 2) {
    e->vpunpcklbw(a23_lo, a[2], n == 3 ? zero : a[3]);
    if (remainder >= 8) {
      e->vpunpckhbw(a23_hi, a[2], n == 3 ? zero : a[3]);
    }
  }

  // Compute the row-wise sum of A for row_offsets
  if (a_sum) {
    if (accumulation) {
      e->vpmaddubsw(a[0], a01_lo, one_epi8);
      e->vpaddsw(a_sum[0], a[0], a_sum[0]);

      if (remainder >= 8) {
        e->vpmaddubsw(a[2], a01_hi, one_epi8);
        e->vpaddsw(a_sum[1], a[2], a_sum[1]);
      }
    } else {
      e->vpmaddubsw(a_sum[0], a01_lo, one_epi8);
      if (remainder >= 8) {
        e->vpmaddubsw(a_sum[1], a01_hi, one_epi8);
      }
    }

    if (n > 2) {
      e->vpmaddubsw(a[0], a23_lo, one_epi8);
      e->vpaddsw(a_sum[0], a[0], a_sum[0]);

      if (remainder >= 8) {
        e->vpmaddubsw(a[2], a23_hi, one_epi8);
        e->vpaddsw(a_sum[1], a[2], a_sum[1]);
      }
    }
  }

  if (n > 2) {
    // Reusing a
    e->vpunpcklwd(a[0], a01_lo, a23_lo);
    e->vpunpckhwd(a[1], a01_lo, a23_lo);
    if (remainder >= 16) {
      e->vpunpcklwd(a[2], a01_hi, a23_hi);
      e->vpunpckhwd(a[3], a01_hi, a23_hi);
    }

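    // vpmaddubsw multiplies the interleaved uint8 inputs with the matching
    // int8 filter bytes and sums adjacent products into int16 lanes
    // (a0 * b0 + a1 * b1 and a2 * b2 + a3 * b3); vpmaddwd against one_epi16
    // then adds those int16 pairs, completing the 4-tap int32 dot product.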
    e->vpmaddubsw(a[0], a[0], x86::ymmword_ptr(b));
    e->vpmaddubsw(a[1], a[1], x86::ymmword_ptr(b, 32));
    if (remainder >= 16) {
      e->vpmaddubsw(a[2], a[2], x86::ymmword_ptr(b, 64));
      e->vpmaddubsw(a[3], a[3], x86::ymmword_ptr(b, 96));
    }

    if (accumulation) {
      e->vpmaddwd(a[0], a[0], one_epi16);
      e->vpaddd(c[0], c[0], a[0]);
      e->vpmaddwd(a[1], a[1], one_epi16);
      e->vpaddd(c[1], c[1], a[1]);

      if (remainder >= 16) {
        e->vpmaddwd(a[2], a[2], one_epi16);
        e->vpaddd(c[2], c[2], a[2]);
        e->vpmaddwd(a[3], a[3], one_epi16);
        e->vpaddd(c[3], c[3], a[3]);
      }
    } else {
      e->vpmaddwd(c[0], a[0], one_epi16);
      e->vpmaddwd(c[1], a[1], one_epi16);

      if (remainder >= 16) {
        e->vpmaddwd(c[2], a[2], one_epi16);
        e->vpmaddwd(c[3], a[3], one_epi16);
      }
    }
  } else {
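    // With at most 2 taps, a single vpmaddubsw already yields the complete
    // dot product in int16 lanes; all that remains is sign-extending (and,
    // for the upper 128-bit lane, extracting) into int32.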
    // Reusing a
    e->vpmaddubsw(a[0], a01_lo, x86::ymmword_ptr(b));
    e->vpmaddubsw(a[1], a01_hi, x86::ymmword_ptr(b, 32));

    if (accumulation) {
      e->vpmovsxwd(a[2], a[0].half());
      e->vpaddd(c[0], c[0], a[2]);
      e->vpmovsxwd(a[3], a[1].half());
      e->vpaddd(c[1], c[1], a[3]);

      if (remainder >= 16) {
        e->vextracti128(a[0].half(), a[0], asmjit::Imm(1));
        e->vpmovsxwd(a[0], a[0].half());
        e->vpaddd(c[2], c[2], a[0]);
        e->vextracti128(a[1].half(), a[1], asmjit::Imm(1));
        e->vpmovsxwd(a[1], a[1].half());
        e->vpaddd(c[3], c[3], a[1]);
      }
    } else {
      e->vpmovsxwd(c[0], a[0].half());
      e->vpmovsxwd(c[1], a[1].half());

      if (remainder >= 16) {
        e->vextracti128(a[0].half(), a[0], asmjit::Imm(1));
        e->vpmovsxwd(c[2], a[0].half());
        e->vextracti128(a[1].half(), a[1], asmjit::Imm(1));
        e->vpmovsxwd(c[3], a[1].half());
      }
    }
  }
}

GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate(
    int D,
    std::array<int, 3> F,
    int oc_per_g,
    bool compute_a_sum,
    int remainder,
    int prev_skip,
    int next_skip,
    int top_skip,
    int bottom_skip,
    int left_skip,
    int right_skip) {
  std::tuple<int, int, int, int, int, bool, int, int, int, int, int, int, int>
      kernelSig = std::make_tuple(
          D,
          F[0],
          F[1],
          F[2],
          oc_per_g,
          compute_a_sum,
          remainder,
          prev_skip,
          next_skip,
          top_skip,
          bottom_skip,
          left_skip,
          right_skip);

  return codeCache_.getOrCreate(kernelSig, [&]() -> jit_kernel_signature {
    asmjit::CodeHolder code;
    code.init(runtime().environment());
    x86::Assembler assembler(&code);
    x86::Emitter* e = assembler.as<x86::Emitter>();
#ifdef FBGEMM_LOG_CODE
    std::string filename = "dwconv_" + std::to_string(D) + "d_";
    for (int i = 3 - D; i < 3; ++i) {
      filename += std::to_string(F[i]);
      if (i < 2) {
        filename += "x";
      }
    }
    filename += "_" + std::to_string(oc_per_g);
    if (compute_a_sum) {
      filename += "_asum";
    }
    if (remainder) {
      filename += "_remainder" + std::to_string(remainder);
    }
    if (prev_skip) {
      filename += "_prev_skip" + std::to_string(prev_skip);
    }
    if (next_skip) {
      filename += "_next_skip" + std::to_string(next_skip);
    }
    if (top_skip) {
      filename += "_top_skip" + std::to_string(top_skip);
    }
    if (bottom_skip) {
      filename += "_bottom_skip" + std::to_string(bottom_skip);
    }
    if (left_skip) {
      filename += "_left_skip" + std::to_string(left_skip);
    }
    if (right_skip) {
      filename += "_right_skip" + std::to_string(right_skip);
    }
    filename += ".txt";
    FILE* codeLogFile = fopen(filename.c_str(), "w");
    asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile);
    code.setLogger(codeLogger);
#endif

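    // Scalar registers: the kernel arguments (assigned from the function
    // signature below) plus loop bookkeeping.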
    x86::Gp a_addr = e->zdi();
    x86::Gp b_addr = e->zsi();
    x86::Gp c_addr = e->zdx();
    x86::Gp a_sum_addr = e->zcx();
    x86::Gp h = e->gpz(8);
    x86::Gp w = e->gpz(9);
    x86::Gp ic = e->gpz(10);
    x86::Gp mask_addr = e->gpz(11);
    x86::Gp a_zero_point = e->gpz(12);
    x86::Gp b_zero_point_addr = e->gpz(13);
    x86::Gp ic_loop_count = e->gpz(14);
    x86::Gp a_addr_save = e->gpz(15);

    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<
            void,
            const std::uint8_t*,
            const std::int8_t*,
            std::int32_t*,
            std::int32_t*,
            int,
            int,
            int,
            const int*,
            int,
            const std::int32_t*>(asmjit::CallConvId::kHost),
        e->environment());

    asmjit::FuncFrame frame;
    frame.init(func);

    frame.setDirtyRegs(
        asmjit::RegGroup::kVec,
        asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
    frame.setDirtyRegs(
        asmjit::RegGroup::kGp,
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(
        a_addr,
        b_addr,
        c_addr,
        a_sum_addr,
        h,
        w,
        ic,
        mask_addr,
        a_zero_point,
        b_zero_point_addr);

    args.updateFuncFrame(frame);
    frame.finalize();

    e->emitProlog(frame);
    e->emitArgsAssignment(frame, args);

    // Assign vector registers
    x86::Ymm a[4];
    x86::Ymm c[4];
    x86::Ymm a_sum[2];

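    // Register budget: ymm0 and ymm1 are scratch; a, c, a_sum, the mask, and
    // the constant vectors are handed out in order from the 16 available ymm
    // registers, which is why the tail of the allocation shares registers
    // when it runs out.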
    int vreg_id = 2; // reserve 2 for temp vreg
    for (int i = 0; i < 4; ++i, ++vreg_id) {
      a[i] = x86::Ymm(vreg_id);
    }
    for (int i = 0; i < 4; ++i, ++vreg_id) {
      c[i] = x86::Ymm(vreg_id);
    }
    if (compute_a_sum) {
      a_sum[0] = x86::Ymm(vreg_id);
      ++vreg_id;
      a_sum[1] = x86::Ymm(vreg_id);
      ++vreg_id;
    }
    x86::Ymm mask_vreg(vreg_id);
    constexpr int vlen = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
    if (remainder != simd_info<inst_set_t::avx2>::WIDTH_BYTES) {
      ++vreg_id;
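      // Assuming the usual FBGEMM mask table layout (ones followed by
      // zeros), this offset loads a mask that keeps the first
      // remainder / 4 / oc_per_g 32-bit lanes active for the masked loads
      // in the last input-channel iteration.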
      e->vmovups(
          mask_vreg,
          x86::ymmword_ptr(
              mask_addr,
              (vlen - remainder / 4 / oc_per_g) % vlen * sizeof(int32_t)));
    }
    x86::Ymm one_epi8(vreg_id);
    if (compute_a_sum) {
      ++vreg_id;
      gen8BitVectorOne(e, one_epi8);
    }

    int K = std::accumulate(F.begin(), F.end(), 1, std::multiplies<int>());
    x86::Ymm one_epi16(vreg_id);
    if (K > 2) {
      ++vreg_id;
      gen16BitVectorOne<inst_set_t::avx2, x86::Ymm>(e, one_epi16);
    }

    bool has_pad = prev_skip || next_skip || top_skip || bottom_skip ||
        left_skip || right_skip;
    bool need_zero = K % 4 == 3 || K % 4 == 1;
    // When out of registers, zero and a_zero_point_vreg need to share.
    bool recompute_zero = vreg_id == 15 && need_zero;

    x86::Ymm a_zero_point_vreg(vreg_id);
    if (!recompute_zero && has_pad) {
      e->movq(a_zero_point_vreg.half(), a_zero_point);
      e->vpbroadcastb(a_zero_point_vreg, a_zero_point_vreg.half());
    }
    if (vreg_id < 15) {
      ++vreg_id;
    }
    x86::Ymm zero(vreg_id);
    if (need_zero && (!recompute_zero || !has_pad)) {
      e->vpxor(zero.xmm(), zero.xmm(), zero.xmm());
    }

    // Assign scalar registers
    e->imul(w, ic);
    e->imul(h, w);
    if (D >= 3) {
      e->mov(a_addr_save, w);
      e->imul(a_addr_save, F[1]);
      e->sub(h, a_addr_save); // h * w * ic - F[1] * w * ic
    }
    e->mov(a_addr_save, ic);
    e->imul(a_addr_save, F[2]);
    e->sub(w, a_addr_save); // w * ic - F[2] * ic

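    // Each iteration over input channels consumes 32 / oc_per_g of them
    // (a 32-byte A load for oc_per_g == 1, 16 bytes for oc_per_g == 2), so
    // the add-then-shift below computes
    // ic_loop_count = ceil(ic / (32 / oc_per_g)); the last iteration is
    // peeled off into the main_loop == false pass.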
    e->mov(ic_loop_count, ic);
    e->add(ic_loop_count, asmjit::Imm(32 / oc_per_g - 1));
    e->sar(ic_loop_count, asmjit::Imm(oc_per_g == 1 ? 5 : 4));

    e->mov(a_addr_save, a_addr);
    asmjit::Label ic_loop_begin = e->newLabel(), ic_loop_end = e->newLabel();

    // main_loop == false: the last vector iteration across input channels
    for (bool main_loop : {true, false}) {
      if (main_loop) {
        e->bind(ic_loop_begin);
        e->dec(ic_loop_count);
        e->jle(ic_loop_end);
      }

      if (recompute_zero && has_pad) {
        e->movq(a_zero_point_vreg.half(), a_zero_point);
        e->vpbroadcastb(a_zero_point_vreg, a_zero_point_vreg.half());
      }

      int i = 0;
      // Iterate across the reduction (filter) dimension
      for (int f_t = 0; f_t < ((D == 2) ? 1 : F[0]); ++f_t) {
        for (int f_h = 0; f_h < F[1]; ++f_h) {
          for (int f_w = 0; f_w < F[2]; ++f_w, ++i) {
            bool pad = false;
            if (D > 2) {
              if (f_t < prev_skip || f_t >= F[0] - next_skip) {
                pad = true;
              }
            }
            if (f_h < top_skip || f_h >= F[1] - bottom_skip ||
                f_w < left_skip || f_w >= F[2] - right_skip) {
              pad = true;
            }

            // Load A
            if (pad) {
              e->vmovups(a[i % 4], a_zero_point_vreg);
            } else {
              if (oc_per_g == 1) {
                if (!main_loop && remainder != 32) {
                  e->vmaskmovps(a[i % 4], mask_vreg, x86::ymmword_ptr(a_addr));
                } else {
                  e->vmovups(a[i % 4], x86::ymmword_ptr(a_addr));
                }
              } else {
                assert(oc_per_g == 2);
                if (!main_loop && remainder != 32) {
                  e->vmaskmovps(
                      a[i % 4].half(),
                      mask_vreg.half(),
                      x86::xmmword_ptr(a_addr));
                } else {
                  e->vmovups(a[i % 4].half(), x86::xmmword_ptr(a_addr));
                }
                // Duplicate each byte.
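                // vpmovzxbw widens each byte to a 16-bit lane, vpsllw makes
                // a copy shifted into the high byte, and vpaddw merges them,
                // so every input byte appears twice, matching B's layout of
                // two output channels per input channel.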
                e->vpmovzxbw(a[i % 4], a[i % 4].half());
                e->vpsllw(x86::ymm(i % 2), a[i % 4], asmjit::Imm(8));
                e->vpaddw(a[i % 4], a[i % 4], x86::ymm(i % 2));
              }
            }

            // Compute when we have 4 inputs or this is the last iteration
            if (i % 4 == 3 || i == K - 1) {
              if (i == K - 1 && (i / 4 * 4 == K - 3 || i / 4 * 4 == K - 1)) {
                if (recompute_zero && has_pad) {
                  e->vpxor(zero.xmm(), zero.xmm(), zero.xmm());
                }
              }

              genMaddEpi16xNPacked(
                  e,
                  a,
                  b_addr,
                  c,
                  compute_a_sum ? a_sum : nullptr,
                  /*n=*/std::min(K - i / 4 * 4, 4),
                  main_loop ? 32 : remainder,
                  /*accumulation=*/i / 4 > 0,
                  one_epi8,
                  one_epi16,
                  zero);

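              // Advance B past the taps just consumed: 32 bytes per tap,
              // with the final advance rounding the remaining tap count up
              // to an even number because B interleaves taps in pairs.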
              if (i != K - 1) {
                e->add(b_addr, asmjit::Imm(32 * 4));
              } else if (main_loop) {
                e->add(b_addr, asmjit::Imm(32 * (K - i / 4 * 4 + 1) / 2 * 2));
              }

              if (K - i / 4 * 4 >= 3 && K - i / 4 * 4 <= 6) {
                // The packed (n >= 3) path of genMaddEpi16xNPacked leaves
                // each c register holding two non-contiguous 4-element
                // chunks (see its header comment); once no further packed
                // group follows, permute the results back into contiguous
                // order.
                for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
                  e->vperm2f128(
                      a[r],
                      c[r % 2 * 2],
                      c[r % 2 * 2 + 1],
                      asmjit::Imm(r < 2 ? 0x20 : 0x31));
                }
                for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
                  e->vmovdqa(c[r], a[r]);
                }
              }
            }
            if (i != K - 1) {
              e->add(a_addr, ic); // advance to next pixel
            }
          }
          if (i != K - 1) {
            e->add(a_addr, w); // advance to next row
          }
        }
        if (D >= 3 && i != K - 1) {
          e->add(a_addr, h); // advance to next frame
        }
      }

      for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
        e->vmovups(x86::ymmword_ptr(c_addr, r * 32), c[r]);
      }

      if (compute_a_sum) {
        if (oc_per_g == 1) {
          e->vpmovsxwd(a[0], a_sum[0].half());
          e->vmovups(x86::ymmword_ptr(a_sum_addr), a[0]);
        } else {
          // Rollback duplication
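          // (each 32-bit lane holds the same 16-bit sum twice because the
          // input bytes were duplicated above; the shift keeps one copy per
          // lane, already widened for the int32 store)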
          e->vpsrld(a_sum[0], a_sum[0], asmjit::Imm(16));
          e->vmovups(x86::xmmword_ptr(a_sum_addr), a_sum[0].half());
        }

        if (main_loop || remainder >= 8) {
          if (oc_per_g == 1) {
            e->vpmovsxwd(a[1], a_sum[1].half());
            e->vmovups(x86::ymmword_ptr(a_sum_addr, 32), a[1]);
          } else {
            // Rollback duplication
            e->vpsrld(a_sum[1], a_sum[1], asmjit::Imm(16));
            e->vmovups(x86::xmmword_ptr(a_sum_addr, 16), a_sum[1].half());
          }
        }

        if (main_loop || remainder >= 16) {
          e->vextracti128(a_sum[0].half(), a_sum[0], asmjit::Imm(1));
          if (oc_per_g == 1) {
            e->vpmovsxwd(a_sum[0], a_sum[0].half());
            e->vmovups(x86::ymmword_ptr(a_sum_addr, 64), a_sum[0]);
          } else {
            e->vmovups(x86::xmmword_ptr(a_sum_addr, 32), a_sum[0].half());
          }
        }

        if (main_loop || remainder >= 24) {
          e->vextracti128(a_sum[1].half(), a_sum[1], asmjit::Imm(1));
          if (oc_per_g == 1) {
            e->vpmovsxwd(a_sum[1], a_sum[1].half());
            e->vmovups(x86::ymmword_ptr(a_sum_addr, 96), a_sum[1]);
          } else {
            e->vmovups(x86::xmmword_ptr(a_sum_addr, 48), a_sum[1].half());
          }
        }

        if (main_loop) {
          e->add(a_sum_addr, asmjit::Imm(128 / oc_per_g));
        }
      }

      if (main_loop) {
        e->add(c_addr, asmjit::Imm(128));
        e->add(a_addr_save, asmjit::Imm(32 / oc_per_g));
        e->mov(a_addr, a_addr_save);
        e->jmp(ic_loop_begin);

        e->bind(ic_loop_end);
      }
    }

    e->emitEpilog(frame);

    jit_kernel_signature fn;
    asmjit::Error err;
    {
      std::unique_lock<std::mutex> lock(rtMutex_);
      err = runtime().add(&fn, &code);
    }
    if (err) {
      std::cerr << "Error: in fn add" << std::endl;
      return nullptr;
    }

#ifdef FBGEMM_LOG_CODE
    fclose(codeLogFile);
    delete codeLogger;
#endif

    return fn;
  });
}

} // namespace fbgemm