1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#pragma once
8#include <asmjit/asmjit.h>
9#include <cpuinfo.h>
10#include <cassert>
11#include <cstdint>
12#include <map>
13#include <mutex>
14#include <sstream>
15#include <string>
16#include <tuple>
17#include <type_traits>
18#include "./CodeCache.h"
19#include "fbgemm/ConvUtils.h"
20#include "fbgemm/Fbgemm.h"
21#include "fbgemm/Utils.h"
22/*#define FBGEMM_LOG_CODE 1*/
23
24namespace fbgemm {
25
26namespace x86 = asmjit::x86;
27
28/**
29 * @brief Generate instructions for initializing the C registers to 0.
30 */
31void initCRegs(x86::Emitter* a, int rowRegs, int colRegs);
32
33template <typename TA, typename TB, typename TC, typename accT>
34class DirectConvCodeGenBase {
35 public:
36 using jit_micro_kernel_fp = void (*)(
37 const TA* bufferA,
38 const TB* bufferB,
39 const TB* b_pf,
40 TC* bufferC,
41 int kc,
42 int ldc);
43
44 // microkernel signature for transposed direct conv
45 // ic: input channel
46 // ldcReg: leading dimension of output, a.k.a OC
47 // o1Xoc: output width multiply output channel:
48 // OUT_DIM[1] x OC
49 using jit_micro_kernel_fp_convT = void (*)(
50 const TA* bufferA,
51 const TB* bufferB,
52 TC* bufferC,
53 int ic,
54 int ldcReg,
55 int o1Xoc,
56 int i1);
57
58 static std::mutex rtMutex_; ///< Control access to runtime;
59
60 // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr
61 static CodeCache<
62 std::tuple<bool, int, int, int, int, int, int>,
63 jit_micro_kernel_fp>
64 codeCache_; ///< JIT Code Cache for reuse.
65
66 // The hash depends on accumulate, stride, mr, nr
67 static CodeCache<
68 std::tuple<bool, int, int, int>,
69 jit_micro_kernel_fp_convT>
70 codeCacheT_; ///< JIT Code Cache for reuse.
71
72 /**
73 * @brief Generate instructions for storing the C registers back to the
74 * memory.
75 */
76 template <inst_set_t instSet>
77 void storeCRegs(
78 x86::Emitter* a,
79 int rowRegs,
80 int colRegs,
81 x86::Gp C_Offset,
82 x86::Gp ldcReg,
83 bool accum);
84
85 /**
86 * @brief Generate instructions for storing the C registers back to the
87 * memory.
88 */
89 template <inst_set_t instSet>
90 void storeCRegsTrans(
91 x86::Emitter* a,
92 int rowRegs,
93 int colRegs,
94 x86::Gp C_offset,
95 x86::Gp o1XocReg,
96 x86::Gp ldcReg,
97 bool accum);
98
99 /**
100 * @brief Generate filename to dump generated code
101 * (debug-only)
102 */
103 template <inst_set_t instSet>
104 static std::string getCodeLoggingFile(
105 bool accum,
106 int mc,
107 int nc,
108 int NCB,
109 int KCB,
110 int MR,
111 int NR) {
112 std::ostringstream oss;
113 oss << "directconv_";
114 if (std::is_same<accT, std::int16_t>::value) {
115 oss << "acc16_";
116 } else if (std::is_same<accT, std::int32_t>::value) {
117 oss << "acc32_";
118 } else {
119 oss << "unknown_";
120 }
121 oss << "accum-" + std::to_string(accum) << "_MC-" + std::to_string(mc)
122 << "_NC-" + std::to_string(nc) << "_NCB-" + std::to_string(NCB)
123 << "_KCB-" + std::to_string(KCB) << "_MR-" + std::to_string(MR)
124 << "_NR-" + std::to_string(NR);
125 if (instSet == inst_set_t::avx512_vnni) {
126 oss << "_avx512vnni";
127 } else if (instSet == inst_set_t::avx512) {
128 oss << "_avx512";
129 } else if (instSet == inst_set_t::avx512_ymm) {
130 oss << "_avx512_ymm";
131 } else if (instSet == inst_set_t::avx2) {
132 oss << "_avx2";
133 }
134 oss << ".txt";
135 return oss.str();
136 }
137
138 /**
139 * @brief Get or Create the instructions for macro-kernel.
140 *
141 * If the problem size (mc, nc) and accumulation flag (accum) can be found in
142 * the code cache (a hash map), then get the macro-kernel instructions
143 * directly from it. Otherwise, create the instructions for macro-kernel, and
144 * store that into the code cache.
145 */
146 template <inst_set_t instSet>
147 jit_micro_kernel_fp
148 getOrCreateDirectConv(bool accum, int32_t mc, int32_t nc, int32_t kc);
149
150 /**
151 * @brief Get or Create the instructions for macro-kernel.
152 *
153 * If the problem size (mc, nc) and accumulation flag (accum) can be found in
154 * the code cache (a hash map), then get the macro-kernel instructions
155 * directly from it. Otherwise, create the instructions for macro-kernel, and
156 * store that into the code cache.
157 */
158 template <inst_set_t instSet>
159 jit_micro_kernel_fp_convT
160 getOrCreateDirectConvTrans(bool accum, int32_t stride, int32_t numColRegs);
161
162 /**
163 * @brief Generate instructions for computing block in the rank-k update.
164 */
165 template <inst_set_t instSet>
166 void genComputeBlock(
167 x86::Emitter* a,
168 x86::Gp buffer_A,
169 x86::Gp buffer_B,
170 x86::Gp B_pf,
171 int rowRegs,
172 int colRegs,
173 int lda);
174 /**
175 * @brief Generate instructions for computing block in the rank-k update.
176 */
177 template <inst_set_t instSet>
178 void genComputeBlockDirectConv(
179 x86::Emitter* a,
180 x86::Gp buffer_A,
181 x86::Gp buffer_B,
182 x86::Gp B_pf,
183 int rowRegs,
184 int colRegs,
185 int strideXich);
186
187 /**
188 * @brief Generate instructions for computing block in the rank-k update.
189 */
190 template <inst_set_t instSet>
191 void genComputeBlockDirectConvTrans(
192 x86::Emitter* a,
193 x86::Gp buffer_A,
194 x86::Gp buffer_B,
195 x86::Gp icReg,
196 x86::Gp C_offset,
197 int rowRegs,
198 int colRegs);
199
200 private:
201 static asmjit::JitRuntime& runtime() {
202 static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
203 // depents on other static
204 // variables. Required to prevent
205 // initialization order fiasco
206 return rt;
207 }
208};
209
210template <typename TA, typename TB, typename TC, typename accT>
211std::mutex DirectConvCodeGenBase<TA, TB, TC, accT>::rtMutex_;
212
213template <typename TA, typename TB, typename TC, typename accT>
214CodeCache<
215 std::tuple<bool, int, int, int, int, int, int>,
216 typename DirectConvCodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
217 DirectConvCodeGenBase<TA, TB, TC, accT>::codeCache_;
218
219template <typename TA, typename TB, typename TC, typename accT>
220CodeCache<
221 std::tuple<bool, int, int, int>,
222 typename DirectConvCodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp_convT>
223 DirectConvCodeGenBase<TA, TB, TC, accT>::codeCacheT_;
224
225}; // namespace fbgemm
226