1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include "libjit_defs.h" |
17 | |
18 | namespace { |
19 | |
20 | /// Macros for accessing submatrices of a matmul using the leading dimension. |
21 | #define A(i, j) a[(j)*lda + (i)] |
22 | #define B(i, j) b[(j)*ldb + (i)] |
23 | #define C(i, j) c[(j)*ldc + (i)] |
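/// For example, with lda == 4, A(1, 2) expands to a[2 * 4 + 1], i.e. row 1 of
/// column 2 in a column-major matrix whose columns are 4 elements apart.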
24 | |
25 | /// Naive gemm helper to handle oddly-sized matrices. |
26 | void libjit_matmul_odd(int m, int n, int k, const float *a, int lda, |
27 | const float *b, int ldb, float *c, int ldc) { |
  // The order of these loops is tuned for column-major matrices: the innermost
  // loop walks down a column of A and C contiguously.
29 | for (int p = 0; p < k; p++) { |
30 | for (int j = 0; j < n; j++) { |
31 | for (int i = 0; i < m; i++) { |
32 | C(i, j) += A(i, p) * B(p, j); |
33 | } |
34 | } |
35 | } |
36 | } |
37 | |
38 | /// Number of registers to use for rows of A in the dot-product kernel. |
39 | constexpr int regsA = 4; |
40 | /// Number of registers to use for columns of B in the dot-product kernel. |
41 | constexpr int regsB = 3; |
42 | |
/// Number of rows of A to process in the kernel. Vector loads are used for A,
/// and each vector register holds eight floats, so we process eight rows of A
/// per register.
45 | constexpr int mr = regsA * 8; |
46 | /// Number of columns of B to process in the kernel. |
47 | constexpr int nr = regsB; |
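/// With regsA = 4 and regsB = 3, the dot-product kernel therefore computes an
/// mr x nr = 32 x 3 block of C per invocation.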
48 | |
49 | /// Blocking parameters for the outer kernel. We multiply mc x kc blocks of A |
50 | /// with kc x nc panels of B (this approach is referred to as `gebp` in the |
51 | /// literature). TODO: Generalize these parameters for other cache sizes. |
52 | constexpr int mc = 256; |
53 | constexpr int kc = 128; |
54 | constexpr int nc = 4096; |
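/// With these defaults an mc x kc block of A is 256 * 128 * sizeof(float) =
/// 128 KB, i.e. about half of a 256 KB L2 (see the note on libjit_matmul_outer
/// below), and a kc x nc panel of B is 128 * 4096 * sizeof(float) = 2 MB,
/// which is assumed to stay resident in the last-level cache while it is
/// streamed through.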
55 | |
/// Compute an (8*RA) x RB block of C using a vectorized dot product, where RA
/// is the number of vector registers to load from matrix A, and RB is the
/// number of registers to broadcast from matrix B.
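/// With regsA = 4 and regsB = 3 this keeps 12 accumulator vectors live, plus
/// one vector for the A load and one for the B broadcast; on an AVX2 target
/// (16 YMM registers) the whole working set fits in registers, which is
/// presumably why these defaults were chosen.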
59 | template <size_t regsA, size_t regsB> |
60 | void libjit_matmul_dot(size_t k, const float *a, size_t lda, const float *b, |
61 | size_t ldb, float *c, size_t ldc) { |
62 | float8 csum[regsA][regsB] = {{0.0}}; |
63 | for (size_t p = 0; p < k; p++) { |
64 | // Perform the DOT product. |
65 | for (size_t ai = 0; ai < regsA; ai++) { |
66 | float8 aa = LoaduFloat8(&A(ai * 8, p)); |
67 | for (size_t bi = 0; bi < regsB; bi++) { |
68 | float8 bb = BroadcastFloat8(B(p, bi)); |
69 | csum[ai][bi] += aa * bb; |
70 | } |
71 | } |
72 | } |
73 | |
74 | // Accumulate the results into C. |
75 | for (size_t bi = 0; bi < regsB; bi++) { |
76 | for (size_t ai = 0; ai < regsA; ai++) { |
77 | AdduFloat8(&C(ai * 8, bi), csum[ai][bi]); |
78 | } |
79 | } |
80 | } |
81 | |
82 | /// Similar to libjit_matmul_dot, but assumes that \p a and \p b have been |
83 | /// packed using z-ordering. |
84 | template <size_t regsA, size_t regsB> |
85 | void libjit_matmul_zdot(size_t k, const float *a, size_t lda, const float *b, |
86 | size_t ldb, float *c, size_t ldc) { |
87 | float8 csum[regsA][regsB] = {{0.0}}; |
88 | |
89 | for (size_t p = 0; p < k; p++) { |
90 | // Perform the DOT product. |
91 | float8 *aptr = (float8 *)&A(0, p); |
92 | for (size_t ai = 0; ai < regsA; ai++) { |
93 | float8 aa = *aptr++; |
94 | for (size_t bi = 0; bi < regsB; bi++) { |
95 | float8 bb = BroadcastFloat8(*(b + bi)); |
96 | csum[ai][bi] += aa * bb; |
97 | } |
98 | } |
99 | b += regsB; |
100 | } |
101 | |
102 | // Accumulate the results into C. |
103 | for (size_t bi = 0; bi < regsB; bi++) { |
104 | for (size_t ai = 0; ai < regsA; ai++) { |
105 | AdduFloat8(&C(ai * 8, bi), csum[ai][bi]); |
106 | } |
107 | } |
108 | } |
109 | |
110 | /// Pack matrix \p a into matrix \p a_to using a z-ordering, so that the |
111 | /// dot-product kernel can stride sequentially through memory. |
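/// For illustration, with regsA = 1 (so mr = 8) and a 16 x 2 input, the packed
/// buffer holds rows 0-7 of column 0, then rows 0-7 of column 1, then rows
/// 8-15 of column 0, then rows 8-15 of column 1, so the kernel reads A through
/// a single sequentially-advancing pointer.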
112 | template <size_t regsA> |
113 | void pack_matrix_a(size_t m, size_t k, const float *a, size_t lda, |
114 | float *a_to) { |
115 | for (int i = 0; i < int(m) - mr + 1; i += mr) { |
116 | for (size_t j = 0; j < k; j++) { |
117 | const float *a_ij_pntr = &A(i, j); |
118 | for (size_t ai = 0; ai < regsA; ai++) { |
119 | StoreuFloat8(a_to + 8 * ai, LoaduFloat8(a_ij_pntr + 8 * ai)); |
120 | } |
121 | a_to += 8 * regsA; |
122 | } |
123 | } |
124 | } |
125 | |
126 | /// Pack matrix \p b into matrix \p b_to using a z-ordering, so that the |
127 | /// dot-product kernel can stride sequentially through memory, rather than |
128 | /// reading from `regsB` separate columns. |
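/// For illustration, with regsB = 3 (nr = 3) the packed buffer interleaves the
/// three columns of each panel row by row: B(0, j), B(0, j+1), B(0, j+2),
/// B(1, j), B(1, j+1), B(1, j+2), and so on, matching the `b += regsB` stride
/// in libjit_matmul_zdot.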
129 | template <size_t regsB> |
130 | void pack_matrix_b(size_t n, size_t k, const float *b, size_t ldb, |
131 | float *b_to) { |
132 | for (int j = 0; j < int(n) - nr + 1; j += nr) { |
133 | for (size_t i = 0; i < k; i++) { |
134 | for (size_t bi = 0; bi < regsB; bi++) { |
135 | *b_to++ = B(i, j + bi); |
136 | } |
137 | } |
138 | } |
139 | } |
140 | |
/// Inner kernel for packed matrices. The order of the M and N loops matters
/// here because the packed kernel is more sensitive to cache locality, and the
/// N loop strides over the B matrix, which is very large and would blow out
/// the cache.
145 | void libjit_matmul_inner_packed(int m, int n, int k, const float *packedA, |
146 | const float *packedB, float *c, int ldc) { |
147 | for (int j = 0; j < n - nr + 1; j += nr) { |
148 | for (int i = 0; i < m - mr + 1; i += mr) { |
149 | libjit_matmul_zdot<regsA, regsB>(k, &packedA[i * k], mr, &packedB[j * k], |
150 | k, &C(i, j), ldc); |
151 | } |
152 | } |
153 | } |
154 | |
155 | /// Inner kernel for non-packed matrices. In these cases N is small, so it |
156 | /// tends to be beneficial to retain locality in the A matrix. |
157 | void libjit_matmul_inner_unpacked(int m, int n, int k, const float *a, int lda, |
158 | const float *b, int ldb, float *c, int ldc) { |
159 | for (int i = 0; i < m - mr + 1; i += mr) { |
160 | for (int j = 0; j < n - nr + 1; j += nr) { |
161 | libjit_matmul_dot<regsA, regsB>(k, &A(i, 0), lda, &B(0, j), ldb, &C(i, j), |
162 | ldc); |
163 | } |
164 | } |
165 | } |
166 | |
167 | /// Compute a portion of C one block at a time. Handle ragged edges with calls |
168 | /// to a slow but general helper. |
169 | template <bool pack> |
170 | void libjit_matmul_inner(int m, int n, int k, const float *a, int lda, |
171 | const float *b, int ldb, float *c, int ldc, |
172 | float *packedB) { |
  // The tiling scheme naturally divides each input matrix into two parts,
  // giving one perfectly-tiled section of C and three "ragged" edges.
175 | // |
176 | // -------------------- ------- |
177 | // | A00*B00 | A00*B01| | A00 | ------------- |
178 | // -------------------- += ------- * | B00 | B01 | |
179 | // | A10*B00 | A10*B01| | A10 | ------------- |
180 | // -------------------- ------- |
181 | // |
  // We can process this as 4 separate matrix multiplications. A00*B00 is the
  // perfectly-tiled portion, which we handle with the (8 * regsA) x regsB
  // dot-product kernel. The ragged edges are (ideally) less critical, so we
  // handle them with a call to a general matrix multiplication for odd sizes.
  if (pack) {
    float packedA[m * k] __attribute__((aligned(64)));
    pack_matrix_a<regsA>(m, k, &A(0, 0), lda, packedA);
    libjit_matmul_inner_packed(m, n, k, packedA, packedB, c, ldc);
  } else {
    libjit_matmul_inner_unpacked(m, n, k, a, lda, b, ldb, c, ldc);
  }
196 | |
197 | sdim_t i = (m / mr) * mr; |
198 | sdim_t j = (n / nr) * nr; |
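  // For example, with m = 70 and mr = 32, i = 64, so rows 64..69 of C fall
  // into the ragged edge handled by the slow path below; the column remainder
  // is handled analogously.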
199 | if (i < m) { |
200 | libjit_matmul_odd(m - i, j, k, &A(i, 0), lda, &B(0, 0), ldb, &C(i, 0), ldc); |
201 | } |
202 | if (j < n) { |
203 | libjit_matmul_odd(i, n - j, k, &A(0, 0), lda, &B(0, j), ldb, &C(0, j), ldc); |
204 | } |
205 | if (i < m && j < n) { |
206 | libjit_matmul_odd(m - i, n - j, k, &A(i, 0), lda, &B(0, j), ldb, &C(i, j), |
207 | ldc); |
208 | } |
209 | } |
210 | |
211 | /// Tile A into mc * kc blocks, where mc and kc are chosen to approximately fit |
212 | /// the L2 cache on recent Intel processors (e.g., 256 KB for Skylake). Stream |
213 | /// kc * n panels of B through memory to compute each mc * n block of C. |
214 | /// \p a is an \p m x \p k column-major matrix; |
215 | /// \p b is a \p k x \p n column-major matrix; |
216 | /// \p c is a \p m x \p n column-major matrix. |
217 | /// \p lda, \p ldb, and \p ldc are the leading dimensions of A, B, and C, |
218 | /// respectively. |
219 | template <bool pack> |
220 | void __attribute__((noinline)) |
221 | libjit_matmul_outer(dim_t m, dim_t n, dim_t k, const float *a, dim_t lda, |
222 | const float *b, dim_t ldb, float *c, dim_t ldc) { |
223 | float *packedB = nullptr; |
224 | if (pack) { |
    libjit_aligned_malloc((void **)&packedB, 64, kc * nc * sizeof(float));
226 | } |
227 | |
228 | for (dim_t p = 0; p < k; p += kc) { |
229 | dim_t pb = MIN(k - p, kc); |
230 | for (dim_t j = 0; j < n; j += nc) { |
231 | dim_t jb = MIN(n - j, nc); |
232 | if (pack) { |
233 | pack_matrix_b<regsB>(jb, pb, &B(p, j), ldb, packedB); |
234 | } |
235 | for (dim_t i = 0; i < m; i += mc) { |
236 | dim_t ib = MIN(m - i, mc); |
237 | libjit_matmul_inner<pack>(ib, jb, pb, &A(i, p), lda, &B(p, j), ldb, |
238 | &C(i, j), ldc, packedB); |
239 | } |
240 | } |
241 | } |
242 | |
243 | if (pack) { |
244 | libjit_aligned_free(packedB); |
245 | } |
246 | } |
247 | |
248 | #undef C |
249 | #undef B |
250 | #undef A |
251 | |
252 | /// Generic template for FullyConnected. The template allows choosing the |
253 | /// element type and bias type. |
254 | template <typename ElemTy, typename BiasElemTy> |
255 | void libjit_fc_generic(ElemTy *outW, const ElemTy *inW, const ElemTy *weightsW, |
256 | const BiasElemTy *biasW, const dim_t *outWdims, |
257 | const dim_t *inWdims, const dim_t *weightsWdims, |
258 | const dim_t *biasWdims, int32_t outOffset, |
259 | int32_t inOffset, int32_t weightsOffset, |
260 | int32_t biasOffset, int32_t biasPre, int32_t biasPost, |
261 | int32_t biasScale, int32_t outPre, int32_t outPost, |
262 | int32_t outScale) { |
263 | dim_t in_w = inWdims[1]; |
264 | dim_t out_h = outWdims[0]; |
265 | dim_t out_w = outWdims[1]; |
266 | for (size_t i = 0; i < out_h; i++) { |
267 | for (size_t j = 0; j < out_w; j++) { |
268 | int32_t sum = libjit_scale<int32_t>(biasW[j] - biasOffset, biasPre, |
269 | biasPost, biasScale, 0); |
270 | for (size_t k = 0; k < in_w; k++) { |
271 | int32_t I = inW[libjit_getXY(inWdims, i, k)]; |
272 | int32_t W = weightsW[libjit_getXY(weightsWdims, k, j)]; |
273 | sum += (I - inOffset) * (W - weightsOffset); |
274 | } |
275 | int32_t scaledSum = |
276 | libjit_scale<int32_t>(sum, outPre, outPost, outScale, outOffset); |
277 | outW[libjit_getXY(outWdims, i, j)] = libjit_clip_i8(scaledSum); |
278 | } |
279 | } |
280 | } |
281 | |
282 | /// Generic template for rowwise quantized FullyConnected. The template allows |
283 | /// choosing element type and bias type. |
284 | template <typename ElemTy, typename BiasElemTy> |
285 | void libjit_rowwise_quantized_fc_generic( |
286 | ElemTy *outW, const ElemTy *inW, const ElemTy *weightsW, |
287 | const BiasElemTy *biasW, const int32_t *weightsOffsets, |
288 | const int32_t *biasPre, const int32_t *biasPost, const int32_t *biasScale, |
289 | const int32_t *outPre, const int32_t *outPost, const int32_t *outScale, |
290 | const dim_t *outWdims, const dim_t *inWdims, const dim_t *weightsWdims, |
291 | const dim_t *biasWdims, dim_t rowNum, int32_t outOffset, int32_t inOffset, |
292 | int32_t biasOffset) { |
293 | dim_t in_w = inWdims[1]; |
294 | dim_t out_h = outWdims[0]; |
295 | dim_t out_w = outWdims[1]; |
296 | |
  // In rowwise quantized FC, the weights matrix is not pre-transposed:
  // out = I * Transpose(W) + B, i.e.
  // out(i, j) = in(i, 0) * weights(j, 0) + in(i, 1) * weights(j, 1) + ... +
  //             in(i, k) * weights(j, k) + bias(j).
300 | for (size_t i = 0; i < out_h; i++) { |
301 | for (size_t j = 0; j < out_w; j++) { |
302 | int32_t sum = 0; |
303 | for (size_t k = 0; k < in_w; k++) { |
304 | int32_t W = weightsW[libjit_getXY(weightsWdims, j, k)]; |
305 | int32_t I = inW[libjit_getXY(inWdims, i, k)]; |
306 | sum += (W - weightsOffsets[j]) * (I - inOffset); |
307 | } |
308 | int32_t B = libjit_scale<int32_t>(biasW[j] - biasOffset, biasPre[j], |
309 | biasPost[j], biasScale[j], 0); |
310 | sum += B; |
311 | int32_t scaledSum = libjit_scale<int32_t>(sum, outPre[j], outPost[j], |
312 | outScale[j], outOffset); |
313 | outW[libjit_getXY(outWdims, i, j)] = libjit_clip_i8(scaledSum); |
314 | } |
315 | } |
316 | } |
317 | } // namespace |
318 | |
319 | extern "C" { |
320 | |
321 | /// Performs the matrix multiplication c = a * b, where c, a, and b are |
322 | /// row-major matrices. |
323 | /// \p c is a m x n matrix, so \p cDims = {m, n} |
324 | /// \p a is a m x k matrix, so \p aDims = {m, k} |
325 | /// \p b is a k x n matrix, so \p bDims = {k, n} |
326 | void libjit_matmul_f(float *c, const float *a, const float *b, |
327 | const dim_t *cDims, const dim_t *aDims, |
328 | const dim_t *bDims) { |
329 | memset(c, 0, cDims[0] * cDims[1] * sizeof(float)); |
330 | // Call the matrix multiplication routine with appropriate dimensions and |
331 | // leading dimensions. The "leading dimension" for a row-major matrix is equal |
332 | // to the number of columns in the matrix. For a, this is k; for b and c, |
333 | // this is n. |
334 | // |
335 | // This "outer" helper assumes the matrices are given in column-major format |
336 | // (the packing algorithm is more effective with column-major matrices), while |
337 | // the input is row-major. So we compute C += B * A, which is equivalent. |
338 | // |
339 | // The matrix multiplication routine is heavily inspired by: |
340 | // https://github.com/flame/how-to-optimize-gemm |
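  //
  // Worked identity behind the operand swap: reading a row-major M x N buffer
  // as column-major gives its transpose, and (A * B)^T = B^T * A^T. So calling
  // the column-major routine with b as the left operand and a as the right
  // writes C^T into c's buffer, which read back row-major is exactly
  // C = A * B.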
341 | int m = cDims[1]; |
342 | int n = cDims[0]; |
343 | int k = aDims[1]; |
344 | |
  // Use the unpacked version, which does not allocate extra heap or stack
  // memory and therefore keeps memory usage predictable. This is very useful
  // when building bundles (AOT) for MCU targets, where the heap and stack are
  // relatively limited in size. By avoiding heap/stack usage, memory
  // consumption is controlled and known exactly (e.g., printed in the bundle
  // API).
350 | libjit_matmul_outer<false>(m, n, k, b, bDims[1], a, aDims[1], c, cDims[1]); |
351 | } |
352 | |
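/// Performs quantized matrix multiplication outW = lhsW * rhsW for row-major
/// int8 matrices. The per-tensor offsets \p lhsOffset and \p rhsOffset are
/// subtracted before accumulating in int32, and the accumulator is rescaled
/// with \p outPre, \p outPost, \p outScale and \p outOffset (via libjit_scale)
/// and clipped back to int8.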
353 | void libjit_matmul_i8(int8_t *outW, const int8_t *lhsW, const int8_t *rhsW, |
354 | const dim_t *outWdims, const dim_t *lhsWdims, |
355 | const dim_t *rhsWdims, int32_t outOffset, |
356 | int32_t lhsOffset, int32_t rhsOffset, int32_t outPre, |
357 | int32_t outPost, int32_t outScale) { |
358 | for (dim_t x = 0; x < outWdims[0]; x++) { |
359 | for (dim_t y = 0; y < outWdims[1]; y++) { |
360 | int32_t sum = 0; |
361 | for (dim_t i = 0; i < lhsWdims[1]; i++) { |
362 | int32_t lhs = lhsW[libjit_getXY(lhsWdims, x, i)] - lhsOffset; |
363 | int32_t rhs = rhsW[libjit_getXY(rhsWdims, i, y)] - rhsOffset; |
364 | sum += lhs * rhs; |
365 | } |
366 | int32_t s = |
367 | libjit_scale<int32_t>(sum, outPre, outPost, outScale, outOffset); |
368 | outW[libjit_getXY(outWdims, x, y)] = libjit_clip_i8(s); |
369 | } |
370 | } |
371 | } |
372 | |
373 | /// FullyConnected with float precision. |
374 | void libjit_fc_f(float *outW, const float *inW, const float *weightsW, |
375 | const float *biasW, const dim_t *outWdims, |
376 | const dim_t *inWdims, const dim_t *weightsWdims, |
377 | const dim_t *biasWdims) { |
378 | dim_t in_w = inWdims[1]; |
379 | dim_t out_h = outWdims[0]; |
380 | dim_t out_w = outWdims[1]; |
381 | for (size_t i = 0; i < out_h; i++) { |
382 | for (size_t j = 0; j < out_w; j++) { |
383 | float sum = biasW[j]; |
384 | for (size_t k = 0; k < in_w; k++) { |
385 | float I = inW[libjit_getXY(inWdims, i, k)]; |
386 | float W = weightsW[libjit_getXY(weightsWdims, k, j)]; |
387 | sum += I * W; |
388 | } |
389 | outW[libjit_getXY(outWdims, i, j)] = sum; |
390 | } |
391 | } |
392 | } |
393 | |
394 | /// FullyConnected with int8 precision and int32 bias. |
395 | void libjit_fc_i8_i32(int8_t *outW, const int8_t *inW, const int8_t *weightsW, |
396 | const int32_t *biasW, const dim_t *outWdims, |
397 | const dim_t *inWdims, const dim_t *weightsWdims, |
398 | const dim_t *biasWdims, int32_t outOffset, |
399 | int32_t inOffset, int32_t weightsOffset, |
400 | int32_t biasOffset, int32_t biasPre, int32_t biasPost, |
401 | int32_t biasScale, int32_t outPre, int32_t outPost, |
402 | int32_t outScale) { |
403 | libjit_fc_generic<int8_t, int32_t>( |
404 | outW, inW, weightsW, biasW, outWdims, inWdims, weightsWdims, biasWdims, |
405 | outOffset, inOffset, weightsOffset, biasOffset, biasPre, biasPost, |
406 | biasScale, outPre, outPost, outScale); |
407 | } |
408 | |
409 | /// FullyConnected with int8 precision and int8 bias. |
410 | void libjit_fc_i8_i8(int8_t *outW, const int8_t *inW, const int8_t *weightsW, |
411 | const int8_t *biasW, const dim_t *outWdims, |
412 | const dim_t *inWdims, const dim_t *weightsWdims, |
413 | const dim_t *biasWdims, int32_t outOffset, |
414 | int32_t inOffset, int32_t weightsOffset, |
415 | int32_t biasOffset, int32_t biasPre, int32_t biasPost, |
416 | int32_t biasScale, int32_t outPre, int32_t outPost, |
417 | int32_t outScale) { |
418 | libjit_fc_generic<int8_t, int8_t>( |
419 | outW, inW, weightsW, biasW, outWdims, inWdims, weightsWdims, biasWdims, |
420 | outOffset, inOffset, weightsOffset, biasOffset, biasPre, biasPost, |
421 | biasScale, outPre, outPost, outScale); |
422 | } |
423 | |
424 | /// Rowwise quantized FullyConnected with int8 precision and int32 bias. |
425 | void libjit_rowwise_quantized_fc_i8_i32( |
426 | int8_t *outW, const int8_t *inW, const int8_t *weightsW, |
427 | const int32_t *biasW, const int32_t *weightsOffsets, const int32_t *biasPre, |
428 | const int32_t *biasPost, const int32_t *biasScale, const int32_t *outPre, |
429 | const int32_t *outPost, const int32_t *outScale, const dim_t *outWdims, |
430 | const dim_t *inWdims, const dim_t *weightsWdims, const dim_t *biasWdims, |
431 | dim_t rowNum, int32_t outOffset, int32_t inOffset, int32_t biasOffset) { |
432 | libjit_rowwise_quantized_fc_generic<int8_t, int32_t>( |
433 | outW, inW, weightsW, biasW, weightsOffsets, biasPre, biasPost, biasScale, |
434 | outPre, outPost, outScale, outWdims, inWdims, weightsWdims, biasWdims, |
435 | rowNum, outOffset, inOffset, biasOffset); |
436 | } |
437 | |
438 | /// Rowwise quantized FullyConnected with int8 precision and int8 bias. |
439 | void libjit_rowwise_quantized_fc_i8_i8( |
440 | int8_t *outW, const int8_t *inW, const int8_t *weightsW, |
441 | const int8_t *biasW, const int32_t *weightsOffsets, const int32_t *biasPre, |
442 | const int32_t *biasPost, const int32_t *biasScale, const int32_t *outPre, |
443 | const int32_t *outPost, const int32_t *outScale, const dim_t *outWdims, |
444 | const dim_t *inWdims, const dim_t *weightsWdims, const dim_t *biasWdims, |
445 | dim_t rowNum, int32_t outOffset, int32_t inOffset, int32_t biasOffset) { |
446 | libjit_rowwise_quantized_fc_generic<int8_t, int8_t>( |
447 | outW, inW, weightsW, biasW, weightsOffsets, biasPre, biasPost, biasScale, |
448 | outPre, outPost, outScale, outWdims, inWdims, weightsWdims, biasWdims, |
449 | rowNum, outOffset, inOffset, biasOffset); |
450 | } |
451 | } |
452 | |