/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "libjit_defs.h"

namespace {

/// Macros for accessing submatrices of a matmul using the leading dimension.
#define A(i, j) a[(j)*lda + (i)]
#define B(i, j) b[(j)*ldb + (i)]
#define C(i, j) c[(j)*ldc + (i)]
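
// For instance, with lda equal to the number of rows of the column-major
// matrix a, A(i, j) expands to a[(j)*lda + (i)], i.e. element i of column j;
// A(2, 1) with lda == 4 reads a[6].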

/// Naive gemm helper to handle oddly-sized matrices.
void libjit_matmul_odd(int m, int n, int k, const float *a, int lda,
                       const float *b, int ldb, float *c, int ldc) {
  // The order of these loops is tuned for column-major matrices.
  for (int p = 0; p < k; p++) {
    for (int j = 0; j < n; j++) {
      for (int i = 0; i < m; i++) {
        C(i, j) += A(i, p) * B(p, j);
      }
    }
  }
}

/// Number of registers to use for rows of A in the dot-product kernel.
constexpr int regsA = 4;
/// Number of registers to use for columns of B in the dot-product kernel.
constexpr int regsB = 3;

/// Number of rows of A to process in the kernel. Vector loads are used for A,
/// so we load eight times as many floats as we use registers.
constexpr int mr = regsA * 8;
/// Number of columns of B to process in the kernel.
constexpr int nr = regsB;
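
// With regsA = 4, regsB = 3 and 8-wide float8 vectors, each kernel invocation
// computes a 32 x 3 tile of C. That keeps regsA * regsB = 12 accumulators
// live, plus one register for the A load and one for the B broadcast, which
// fits in the 16 vector registers of an AVX-class target (an assumption about
// the target; the register budget may differ elsewhere).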

/// Blocking parameters for the outer kernel. We multiply mc x kc blocks of A
/// with kc x nc panels of B (this approach is referred to as `gebp` in the
/// literature). TODO: Generalize these parameters for other cache sizes.
constexpr int mc = 256;
constexpr int kc = 128;
constexpr int nc = 4096;
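
// With these values, an mc x kc block of A occupies 256 * 128 * 4 bytes =
// 128 KiB (roughly half of the 256 KiB L2 mentioned below), and a packed
// kc x nc panel of B occupies 128 * 4096 * 4 bytes = 2 MiB, which is expected
// to stay in the last-level cache.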

/// Compute a RAxRB block of C using a vectorized dot product, where RA is the
/// number of registers to load from matrix A, and RB is the number of
/// registers to load from matrix B.
template <size_t regsA, size_t regsB>
void libjit_matmul_dot(size_t k, const float *a, size_t lda, const float *b,
                       size_t ldb, float *c, size_t ldc) {
  float8 csum[regsA][regsB] = {{0.0}};
  for (size_t p = 0; p < k; p++) {
    // Perform the DOT product.
    for (size_t ai = 0; ai < regsA; ai++) {
      float8 aa = LoaduFloat8(&A(ai * 8, p));
      for (size_t bi = 0; bi < regsB; bi++) {
        float8 bb = BroadcastFloat8(B(p, bi));
        csum[ai][bi] += aa * bb;
      }
    }
  }

  // Accumulate the results into C.
  for (size_t bi = 0; bi < regsB; bi++) {
    for (size_t ai = 0; ai < regsA; ai++) {
      AdduFloat8(&C(ai * 8, bi), csum[ai][bi]);
    }
  }
}

/// Similar to libjit_matmul_dot, but assumes that \p a and \p b have been
/// packed using z-ordering.
template <size_t regsA, size_t regsB>
void libjit_matmul_zdot(size_t k, const float *a, size_t lda, const float *b,
                        size_t ldb, float *c, size_t ldc) {
  float8 csum[regsA][regsB] = {{0.0}};

  for (size_t p = 0; p < k; p++) {
    // Perform the DOT product.
    float8 *aptr = (float8 *)&A(0, p);
    for (size_t ai = 0; ai < regsA; ai++) {
      float8 aa = *aptr++;
      for (size_t bi = 0; bi < regsB; bi++) {
        float8 bb = BroadcastFloat8(*(b + bi));
        csum[ai][bi] += aa * bb;
      }
    }
    b += regsB;
  }

  // Accumulate the results into C.
  for (size_t bi = 0; bi < regsB; bi++) {
    for (size_t ai = 0; ai < regsA; ai++) {
      AdduFloat8(&C(ai * 8, bi), csum[ai][bi]);
    }
  }
}

/// Pack matrix \p a into matrix \p a_to using a z-ordering, so that the
/// dot-product kernel can stride sequentially through memory.
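/// Concretely, for each band of mr (= 32) rows, the packed buffer stores the
/// band column by column: 32 contiguous floats from column 0, then 32 from
/// column 1, and so on, so the zdot kernel can walk packedA with a single
/// incrementing pointer. Leftover rows (m % mr) are not packed and are
/// handled by the ragged-edge path in libjit_matmul_inner.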
template <size_t regsA>
void pack_matrix_a(size_t m, size_t k, const float *a, size_t lda,
                   float *a_to) {
  for (int i = 0; i < int(m) - mr + 1; i += mr) {
    for (size_t j = 0; j < k; j++) {
      const float *a_ij_pntr = &A(i, j);
      for (size_t ai = 0; ai < regsA; ai++) {
        StoreuFloat8(a_to + 8 * ai, LoaduFloat8(a_ij_pntr + 8 * ai));
      }
      a_to += 8 * regsA;
    }
  }
}

/// Pack matrix \p b into matrix \p b_to using a z-ordering, so that the
/// dot-product kernel can stride sequentially through memory, rather than
/// reading from `regsB` separate columns.
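/// Concretely, for each group of nr (= 3) columns, the packed buffer stores
/// one element from each of the nr columns for row 0, then row 1, and so on,
/// interleaving the columns. Columns beyond the last full group (n % nr) are
/// not packed and are handled by the ragged-edge path in libjit_matmul_inner.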
template <size_t regsB>
void pack_matrix_b(size_t n, size_t k, const float *b, size_t ldb,
                   float *b_to) {
  for (int j = 0; j < int(n) - nr + 1; j += nr) {
    for (size_t i = 0; i < k; i++) {
      for (size_t bi = 0; bi < regsB; bi++) {
        *b_to++ = B(i, j + bi);
      }
    }
  }
}

/// Inner kernel for packed matrices. The order of the M and N loops matters:
/// packed matrices are more sensitive to cache locality, and the N loop
/// strides over the B matrix, which is very large and would otherwise blow
/// out the cache.
void libjit_matmul_inner_packed(int m, int n, int k, const float *packedA,
                                const float *packedB, float *c, int ldc) {
  for (int j = 0; j < n - nr + 1; j += nr) {
    for (int i = 0; i < m - mr + 1; i += mr) {
      libjit_matmul_zdot<regsA, regsB>(k, &packedA[i * k], mr, &packedB[j * k],
                                       k, &C(i, j), ldc);
    }
  }
}

/// Inner kernel for non-packed matrices. In these cases N is small, so it
/// tends to be beneficial to retain locality in the A matrix.
void libjit_matmul_inner_unpacked(int m, int n, int k, const float *a, int lda,
                                  const float *b, int ldb, float *c, int ldc) {
  for (int i = 0; i < m - mr + 1; i += mr) {
    for (int j = 0; j < n - nr + 1; j += nr) {
      libjit_matmul_dot<regsA, regsB>(k, &A(i, 0), lda, &B(0, j), ldb,
                                      &C(i, j), ldc);
    }
  }
}

/// Compute a portion of C one block at a time. Handle ragged edges with calls
/// to a slow but general helper.
template <bool pack>
void libjit_matmul_inner(int m, int n, int k, const float *a, int lda,
                         const float *b, int ldb, float *c, int ldc,
                         float *packedB) {
  // The tiling scheme naturally divides each input matrix into two parts,
  // yielding one perfectly-tiled section of C and three "ragged" edges.
  //
  //   --------------------        -------
  //   | A00*B00 | A00*B01|        | A00 |    -------------
  //   --------------------   +=   ------- *  | B00 | B01 |
  //   | A10*B00 | A10*B01|        | A10 |    -------------
  //   --------------------        -------
  //
  // We can process this as 4 separate matrix multiplications. A00*B00 is the
  // perfectly-tiled portion, which we handle with the vectorized dot-product
  // kernel. The ragged edges are (ideally) less critical, so we handle them
  // with a call to a general matrix multiplication for odd sizes.
  float packedA[m * k] __attribute__((aligned(64)));
  if (pack) {
    pack_matrix_a<regsA>(m, k, &A(0, 0), lda, packedA);
  }

  if (pack) {
    libjit_matmul_inner_packed(m, n, k, packedA, packedB, c, ldc);
  } else {
    libjit_matmul_inner_unpacked(m, n, k, a, lda, b, ldb, c, ldc);
  }
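
  // i and j below are the extents of the fully-tiled region of C. The three
  // calls that follow sweep the remaining bottom rows, the remaining
  // right-hand columns, and the bottom-right corner covered by neither, using
  // the slow helper on the original (unpacked) operands.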
  sdim_t i = (m / mr) * mr;
  sdim_t j = (n / nr) * nr;
  if (i < m) {
    libjit_matmul_odd(m - i, j, k, &A(i, 0), lda, &B(0, 0), ldb, &C(i, 0), ldc);
  }
  if (j < n) {
    libjit_matmul_odd(i, n - j, k, &A(0, 0), lda, &B(0, j), ldb, &C(0, j), ldc);
  }
  if (i < m && j < n) {
    libjit_matmul_odd(m - i, n - j, k, &A(i, 0), lda, &B(0, j), ldb, &C(i, j),
                      ldc);
  }
}

/// Tile A into mc * kc blocks, where mc and kc are chosen to approximately fit
/// the L2 cache on recent Intel processors (e.g., 256 KB for Skylake). Stream
/// kc * n panels of B through memory to compute each mc * n block of C.
/// \p a is an \p m x \p k column-major matrix;
/// \p b is a \p k x \p n column-major matrix;
/// \p c is an \p m x \p n column-major matrix.
/// \p lda, \p ldb, and \p ldc are the leading dimensions of A, B, and C,
/// respectively.
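/// The loop nest below walks kc-wide slices of the K dimension, then nc-wide
/// panels of B's columns, then mc-tall blocks of A's rows; when \p pack is
/// true, each kc x nc panel of B is packed once and reused for every mc block
/// of A within that panel.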
template <bool pack>
void __attribute__((noinline))
libjit_matmul_outer(dim_t m, dim_t n, dim_t k, const float *a, dim_t lda,
                    const float *b, dim_t ldb, float *c, dim_t ldc) {
  float *packedB = nullptr;
  if (pack) {
    libjit_aligned_malloc((void **)&packedB, 64, kc * nc * sizeof(float));
  }

  for (dim_t p = 0; p < k; p += kc) {
    dim_t pb = MIN(k - p, kc);
    for (dim_t j = 0; j < n; j += nc) {
      dim_t jb = MIN(n - j, nc);
      if (pack) {
        pack_matrix_b<regsB>(jb, pb, &B(p, j), ldb, packedB);
      }
      for (dim_t i = 0; i < m; i += mc) {
        dim_t ib = MIN(m - i, mc);
        libjit_matmul_inner<pack>(ib, jb, pb, &A(i, p), lda, &B(p, j), ldb,
                                  &C(i, j), ldc, packedB);
      }
    }
  }

  if (pack) {
    libjit_aligned_free(packedB);
  }
}

#undef C
#undef B
#undef A

/// Generic template for FullyConnected. The template allows choosing the
/// element type and bias type.
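/// The accumulation follows the usual quantized scheme: the bias is first
/// rescaled into the accumulator's scale via libjit_scale, the products of
/// offset-corrected inputs and weights are summed in int32, and the sum is
/// rescaled to the output's quantization parameters and clipped to int8.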
template <typename ElemTy, typename BiasElemTy>
void libjit_fc_generic(ElemTy *outW, const ElemTy *inW, const ElemTy *weightsW,
                       const BiasElemTy *biasW, const dim_t *outWdims,
                       const dim_t *inWdims, const dim_t *weightsWdims,
                       const dim_t *biasWdims, int32_t outOffset,
                       int32_t inOffset, int32_t weightsOffset,
                       int32_t biasOffset, int32_t biasPre, int32_t biasPost,
                       int32_t biasScale, int32_t outPre, int32_t outPost,
                       int32_t outScale) {
  dim_t in_w = inWdims[1];
  dim_t out_h = outWdims[0];
  dim_t out_w = outWdims[1];
  for (size_t i = 0; i < out_h; i++) {
    for (size_t j = 0; j < out_w; j++) {
      int32_t sum = libjit_scale<int32_t>(biasW[j] - biasOffset, biasPre,
                                          biasPost, biasScale, 0);
      for (size_t k = 0; k < in_w; k++) {
        int32_t I = inW[libjit_getXY(inWdims, i, k)];
        int32_t W = weightsW[libjit_getXY(weightsWdims, k, j)];
        sum += (I - inOffset) * (W - weightsOffset);
      }
      int32_t scaledSum =
          libjit_scale<int32_t>(sum, outPre, outPost, outScale, outOffset);
      outW[libjit_getXY(outWdims, i, j)] = libjit_clip_i8(scaledSum);
    }
  }
}

/// Generic template for rowwise quantized FullyConnected. The template allows
/// choosing element type and bias type.
template <typename ElemTy, typename BiasElemTy>
void libjit_rowwise_quantized_fc_generic(
    ElemTy *outW, const ElemTy *inW, const ElemTy *weightsW,
    const BiasElemTy *biasW, const int32_t *weightsOffsets,
    const int32_t *biasPre, const int32_t *biasPost, const int32_t *biasScale,
    const int32_t *outPre, const int32_t *outPost, const int32_t *outScale,
    const dim_t *outWdims, const dim_t *inWdims, const dim_t *weightsWdims,
    const dim_t *biasWdims, dim_t rowNum, int32_t outOffset, int32_t inOffset,
    int32_t biasOffset) {
  dim_t in_w = inWdims[1];
  dim_t out_h = outWdims[0];
  dim_t out_w = outWdims[1];

  // In rowwise quantized FC, the weights are not pre-transposed; we compute
  // I * Transpose(W) + B, i.e.:
  //   out(i, j) = in(i, 0) * weights(j, 0) + in(i, 1) * weights(j, 1) + ... +
  //               in(i, k) * weights(j, k) + bias(j)
  for (size_t i = 0; i < out_h; i++) {
    for (size_t j = 0; j < out_w; j++) {
      int32_t sum = 0;
      for (size_t k = 0; k < in_w; k++) {
        int32_t W = weightsW[libjit_getXY(weightsWdims, j, k)];
        int32_t I = inW[libjit_getXY(inWdims, i, k)];
        sum += (W - weightsOffsets[j]) * (I - inOffset);
      }
      int32_t B = libjit_scale<int32_t>(biasW[j] - biasOffset, biasPre[j],
                                        biasPost[j], biasScale[j], 0);
      sum += B;
      int32_t scaledSum = libjit_scale<int32_t>(sum, outPre[j], outPost[j],
                                                outScale[j], outOffset);
      outW[libjit_getXY(outWdims, i, j)] = libjit_clip_i8(scaledSum);
    }
  }
}
} // namespace

extern "C" {

/// Performs the matrix multiplication c = a * b, where c, a, and b are
/// row-major matrices.
/// \p c is an m x n matrix, so \p cDims = {m, n}
/// \p a is an m x k matrix, so \p aDims = {m, k}
/// \p b is a k x n matrix, so \p bDims = {k, n}
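/// For example, multiplying a 2 x 3 matrix a by a 3 x 4 matrix b produces a
/// 2 x 4 matrix c, so aDims = {2, 3}, bDims = {3, 4}, and cDims = {2, 4}.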
void libjit_matmul_f(float *c, const float *a, const float *b,
                     const dim_t *cDims, const dim_t *aDims,
                     const dim_t *bDims) {
  memset(c, 0, cDims[0] * cDims[1] * sizeof(float));
  // Call the matrix multiplication routine with appropriate dimensions and
  // leading dimensions. The "leading dimension" for a row-major matrix is
  // equal to the number of columns in the matrix. For a, this is k; for b and
  // c, this is n.
  //
  // This "outer" helper assumes the matrices are given in column-major format
  // (the packing algorithm is more effective with column-major matrices),
  // while the input is row-major. So we compute C += B * A, which is
  // equivalent.
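  // (A row-major matrix reinterpreted as column-major data is its transpose,
  // and (A * B)^T = B^T * A^T, so multiplying the reinterpreted b by the
  // reinterpreted a yields exactly the row-major C = A * B.)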
  //
  // The matrix multiplication routine is heavily inspired by:
  // https://github.com/flame/how-to-optimize-gemm
  int m = cDims[1];
  int n = cDims[0];
  int k = aDims[1];

  // Use the unpacked version, which does not use extra HEAP or STACK and thus
  // keeps the memory usage predictable. This is very useful when building
  // bundles (AOT) for MCU targets, where the HEAP and STACK are relatively
  // limited in size. By avoiding heap/stack usage, the memory consumption is
  // controlled and known exactly (e.g. printed in the bundle API).
  libjit_matmul_outer<false>(m, n, k, b, bDims[1], a, aDims[1], c, cDims[1]);
}
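
/// Performs the quantized matrix multiplication outW = lhsW * rhsW for
/// row-major int8 matrices. Each term subtracts the per-tensor offsets from
/// the int8 operands and accumulates in int32; the final sum is rescaled to
/// the output's quantization parameters and clipped back to int8.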
void libjit_matmul_i8(int8_t *outW, const int8_t *lhsW, const int8_t *rhsW,
                      const dim_t *outWdims, const dim_t *lhsWdims,
                      const dim_t *rhsWdims, int32_t outOffset,
                      int32_t lhsOffset, int32_t rhsOffset, int32_t outPre,
                      int32_t outPost, int32_t outScale) {
  for (dim_t x = 0; x < outWdims[0]; x++) {
    for (dim_t y = 0; y < outWdims[1]; y++) {
      int32_t sum = 0;
      for (dim_t i = 0; i < lhsWdims[1]; i++) {
        int32_t lhs = lhsW[libjit_getXY(lhsWdims, x, i)] - lhsOffset;
        int32_t rhs = rhsW[libjit_getXY(rhsWdims, i, y)] - rhsOffset;
        sum += lhs * rhs;
      }
      int32_t s =
          libjit_scale<int32_t>(sum, outPre, outPost, outScale, outOffset);
      outW[libjit_getXY(outWdims, x, y)] = libjit_clip_i8(s);
    }
  }
}

/// FullyConnected with float precision.
void libjit_fc_f(float *outW, const float *inW, const float *weightsW,
                 const float *biasW, const dim_t *outWdims,
                 const dim_t *inWdims, const dim_t *weightsWdims,
                 const dim_t *biasWdims) {
  dim_t in_w = inWdims[1];
  dim_t out_h = outWdims[0];
  dim_t out_w = outWdims[1];
  for (size_t i = 0; i < out_h; i++) {
    for (size_t j = 0; j < out_w; j++) {
      float sum = biasW[j];
      for (size_t k = 0; k < in_w; k++) {
        float I = inW[libjit_getXY(inWdims, i, k)];
        float W = weightsW[libjit_getXY(weightsWdims, k, j)];
        sum += I * W;
      }
      outW[libjit_getXY(outWdims, i, j)] = sum;
    }
  }
}

/// FullyConnected with int8 precision and int32 bias.
void libjit_fc_i8_i32(int8_t *outW, const int8_t *inW, const int8_t *weightsW,
                      const int32_t *biasW, const dim_t *outWdims,
                      const dim_t *inWdims, const dim_t *weightsWdims,
                      const dim_t *biasWdims, int32_t outOffset,
                      int32_t inOffset, int32_t weightsOffset,
                      int32_t biasOffset, int32_t biasPre, int32_t biasPost,
                      int32_t biasScale, int32_t outPre, int32_t outPost,
                      int32_t outScale) {
  libjit_fc_generic<int8_t, int32_t>(
      outW, inW, weightsW, biasW, outWdims, inWdims, weightsWdims, biasWdims,
      outOffset, inOffset, weightsOffset, biasOffset, biasPre, biasPost,
      biasScale, outPre, outPost, outScale);
}

/// FullyConnected with int8 precision and int8 bias.
void libjit_fc_i8_i8(int8_t *outW, const int8_t *inW, const int8_t *weightsW,
                     const int8_t *biasW, const dim_t *outWdims,
                     const dim_t *inWdims, const dim_t *weightsWdims,
                     const dim_t *biasWdims, int32_t outOffset,
                     int32_t inOffset, int32_t weightsOffset,
                     int32_t biasOffset, int32_t biasPre, int32_t biasPost,
                     int32_t biasScale, int32_t outPre, int32_t outPost,
                     int32_t outScale) {
  libjit_fc_generic<int8_t, int8_t>(
      outW, inW, weightsW, biasW, outWdims, inWdims, weightsWdims, biasWdims,
      outOffset, inOffset, weightsOffset, biasOffset, biasPre, biasPost,
      biasScale, outPre, outPost, outScale);
}

/// Rowwise quantized FullyConnected with int8 precision and int32 bias.
void libjit_rowwise_quantized_fc_i8_i32(
    int8_t *outW, const int8_t *inW, const int8_t *weightsW,
    const int32_t *biasW, const int32_t *weightsOffsets, const int32_t *biasPre,
    const int32_t *biasPost, const int32_t *biasScale, const int32_t *outPre,
    const int32_t *outPost, const int32_t *outScale, const dim_t *outWdims,
    const dim_t *inWdims, const dim_t *weightsWdims, const dim_t *biasWdims,
    dim_t rowNum, int32_t outOffset, int32_t inOffset, int32_t biasOffset) {
  libjit_rowwise_quantized_fc_generic<int8_t, int32_t>(
      outW, inW, weightsW, biasW, weightsOffsets, biasPre, biasPost, biasScale,
      outPre, outPost, outScale, outWdims, inWdims, weightsWdims, biasWdims,
      rowNum, outOffset, inOffset, biasOffset);
}

/// Rowwise quantized FullyConnected with int8 precision and int8 bias.
void libjit_rowwise_quantized_fc_i8_i8(
    int8_t *outW, const int8_t *inW, const int8_t *weightsW,
    const int8_t *biasW, const int32_t *weightsOffsets, const int32_t *biasPre,
    const int32_t *biasPost, const int32_t *biasScale, const int32_t *outPre,
    const int32_t *outPost, const int32_t *outScale, const dim_t *outWdims,
    const dim_t *inWdims, const dim_t *weightsWdims, const dim_t *biasWdims,
    dim_t rowNum, int32_t outOffset, int32_t inOffset, int32_t biasOffset) {
  libjit_rowwise_quantized_fc_generic<int8_t, int8_t>(
      outW, inW, weightsW, biasW, weightsOffsets, biasPre, biasPost, biasScale,
      outPre, outPost, outScale, outWdims, inWdims, weightsWdims, biasWdims,
      rowNum, outOffset, inOffset, biasOffset);
}
}