1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
#include "fbgemm/Utils.h"
#include <cpuinfo.h>
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <new>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
21
22namespace fbgemm {
23
/**
 * @brief Compare the reference and test result matrix to check the correctness.
 * @param ref The buffer for the reference result matrix.
 * @param test The buffer for the test result matrix.
 * @param m The height of the reference and test result matrix.
 * @param n The width of the reference and test result matrix.
 * @param ld The leading dimension of the reference and test result matrix.
 * @param max_mismatches_to_report The maximum number of tolerable mismatches to
 * report.
 * @param atol The tolerable error.
 * @retval 1 If the number of mismatches between the reference and test result
 * matrices exceeds max_mismatches_to_report.
 * @retval 0 If the number of mismatches between the reference and test result
 * matrices is tolerable.
 */
template <typename T>
int compare_buffers(
    const T* ref,
    const T* test,
    int m,
    int n,
    int ld,
    size_t max_mismatches_to_report,
    float atol /*=1e-3*/) {
  size_t num_mismatches = 0;
  for (int row = 0; row < m; ++row) {
    for (int col = 0; col < n; ++col) {
      const T expected = ref[row * ld + col];
      const T observed = test[row * ld + col];
      // Within tolerance: nothing to report for this element.
      if (!(std::abs(expected - observed) > atol)) {
        continue;
      }
      std::cout << "\tmismatch at (" << row << ", " << col << ")" << std::endl;
      if (std::is_integral<T>::value) {
        // Print integral types via int64_t so (u)int8_t is not rendered as a
        // character.
        std::cout << "\t reference:" << static_cast<int64_t>(expected)
                  << " test:" << static_cast<int64_t>(observed) << std::endl;
      } else {
        std::cout << "\t reference:" << expected << " test:" << observed
                  << std::endl;
      }

      ++num_mismatches;
      if (num_mismatches > max_mismatches_to_report) {
        return 1;
      }
    }
  }
  return 0;
}
71
72/**
73 * @brief Print the matrix.
74 * @param op Transpose type of the matrix.
75 * @param R The height of the matrix.
76 * @param C The width of the matrix.
77 * @param ld The leading dimension of the matrix.
78 * @param name The prefix string before printing the matrix.
79 */
80template <typename T>
81void printMatrix(
82 matrix_op_t op,
83 const T* inp,
84 size_t R,
85 size_t C,
86 size_t ld,
87 std::string name) {
88 // R: number of rows in op(inp)
89 // C: number of cols in op(inp)
90 // ld: leading dimension in inp
91 std::cout << name << ":"
92 << "[" << R << ", " << C << "]" << std::endl;
93 bool tr = (op == matrix_op_t::Transpose);
94 for (size_t r = 0; r < R; ++r) {
95 for (size_t c = 0; c < C; ++c) {
96 T res = tr ? inp[c * ld + r] : inp[r * ld + c];
97 if (std::is_integral<T>::value) {
98 std::cout << std::setw(5) << static_cast<int64_t>(res) << " ";
99 } else {
100 std::cout << std::setw(5) << res << " ";
101 }
102 }
103 std::cout << std::endl;
104 }
105}
106
// Explicit instantiations of compare_buffers for the supported element types.
template int compare_buffers<float>(
    const float* ref,
    const float* test,
    int m,
    int n,
    int ld,
    size_t max_mismatches_to_report,
    float atol);

template int compare_buffers<int32_t>(
    const int32_t* ref,
    const int32_t* test,
    int m,
    int n,
    int ld,
    size_t max_mismatches_to_report,
    float atol);

template int compare_buffers<uint8_t>(
    const uint8_t* ref,
    const uint8_t* test,
    int m,
    int n,
    int ld,
    size_t max_mismatches_to_report,
    float atol);

template int compare_buffers<int64_t>(
    const int64_t* ref,
    const int64_t* test,
    int m,
    int n,
    int ld,
    size_t max_mismatches_to_report,
    float atol);
142
// Explicit instantiations of printMatrix for the supported element types.
template void printMatrix<float>(
    matrix_op_t op,
    const float* inp,
    size_t R,
    size_t C,
    size_t ld,
    std::string name);
template void printMatrix<int8_t>(
    matrix_op_t op,
    const int8_t* inp,
    size_t R,
    size_t C,
    size_t ld,
    std::string name);
template void printMatrix<uint8_t>(
    matrix_op_t op,
    const uint8_t* inp,
    size_t R,
    size_t C,
    size_t ld,
    std::string name);
template void printMatrix<int32_t>(
    matrix_op_t op,
    const int32_t* inp,
    size_t R,
    size_t C,
    size_t ld,
    std::string name);
171
namespace {
// Process-wide ISA override installed by fbgemmForceIsa(); anyarch means
// "no override, use the detected ISA".
inst_set_t g_forced_isa = inst_set_t::anyarch;
// Set by fbgemmEnableAvx512Ymm(); when true, the 256-bit (Ymm) variants of
// the AVX-512 kernels are preferred on Xeon D parts.
bool g_Avx512_Ymm_enabled = false;

// Reads the FBGEMM_ENABLE_INSTRUCTIONS environment variable and maps its
// (case-insensitive) value — AVX2, AVX512, AVX512_E1, AVX512_256 or
// AVX512_E1_256 — to an inst_set_t. Returns anyarch when the variable is
// unset or unrecognized, and always on aarch64 where these x86 ISAs do not
// apply.
inst_set_t fbgemmEnvGetIsa() {
  static const char* isa_env = "FBGEMM_ENABLE_INSTRUCTIONS";
  static const std::unordered_map<std::string, inst_set_t> isaMap = {
      {"AVX2", inst_set_t::avx2},
      {"AVX512", inst_set_t::avx512},
      {"AVX512_E1", inst_set_t::avx512_vnni},
      {"AVX512_256", inst_set_t::avx512_ymm},
      {"AVX512_E1_256", inst_set_t::avx512_vnni_ymm},
  };
  const char* env = std::getenv(isa_env);
  if (env == nullptr) {
    return inst_set_t::anyarch;
  }

#ifdef __aarch64__
#ifdef VLOG
  VLOG(0) << "[" << env << "] not supported on aarch64";
#endif
  return inst_set_t::anyarch;
#endif

  // Uppercase the value so the map lookup is case-insensitive.
  std::string val(env);
  std::transform(val.begin(), val.end(), val.begin(), ::toupper);
  auto it = isaMap.find(val);
  return it == isaMap.end() ? inst_set_t::anyarch : it->second;
}

// Returns true when the FBGEMM_ENABLE_AVX512_256 environment variable is set
// to "true" or "1" (case-insensitive). Always false on aarch64.
bool fbgemmEnvAvx512_256Enabled() {
  static const char* isa_env = "FBGEMM_ENABLE_AVX512_256";
  const char* env = std::getenv(isa_env);
  if (env == nullptr) {
    return false;
  }

#ifdef __aarch64__
#ifdef VLOG
  VLOG(0) << "[" << env << "] not supported on aarch64";
#endif
  return false;
#endif

  std::string val(env);
  std::transform(val.begin(), val.end(), val.begin(), ::tolower);
  return val == "true" || val == "1";
}

// Explicit hash functor for inst_set_t; required to build with older
// compilers (GCC 5.4 and C++11).
struct inst_set_t_hash {
  std::size_t operator()(inst_set_t t) const {
    return static_cast<std::size_t>(t);
  }
};

// Maps a detected ISA to the set of ISAs that may legitimately be forced on
// top of it (i.e. kernels the detected hardware can still execute). Consulted
// by fbgemmInstructionSet() to validate a forced ISA before honoring it.
std::unordered_map<
    inst_set_t,
    std::unordered_set<inst_set_t, inst_set_t_hash>,
    inst_set_t_hash>
    isaSupportMap = {
        {inst_set_t::anyarch, {inst_set_t::anyarch}},
        {inst_set_t::avx2, {inst_set_t::avx2, inst_set_t::anyarch}},
        {inst_set_t::avx512,
         {inst_set_t::avx512, inst_set_t::avx512_ymm, inst_set_t::avx2}},
        {inst_set_t::avx512_ymm,
         {inst_set_t::avx512, inst_set_t::avx512_ymm, inst_set_t::avx2}},
        {inst_set_t::avx512_vnni,
         {inst_set_t::avx512_vnni,
          inst_set_t::avx512_vnni_ymm,
          inst_set_t::avx512,
          inst_set_t::avx512_ymm,
          inst_set_t::avx2}},
        {inst_set_t::avx512_vnni_ymm,
         {inst_set_t::avx512_vnni,
          inst_set_t::avx512_vnni_ymm,
          inst_set_t::avx512,
          inst_set_t::avx512_ymm,
          inst_set_t::avx2}},
};

} // namespace
255
256/**
257 * @brief Force specific architecure to for GEMM kernel execution
258 * overides FBGEMM_ENABLE_AVX512_256 env. variable
259 * @param isa the ISA to enforce, supported optionsi
260 * AVX2 inst_set_t::avx2
261 * AVX512 inst_set_t::avx512
262 * AVX512_E1 inst_set_t::avx512_vnni
263 * AVX512_256 inst_set_t::avx512_ymm
264 * AVX512_E1_256 inst_set_t::avx512_vnni_ymm
265 */
266void fbgemmForceIsa(inst_set_t isa) {
267 g_forced_isa = isa;
268#ifdef __aarch64__
269#ifdef VLOG
270 VLOG(0) << "[anyarch] forced on aarch64";
271#endif
272 g_forced_isa = inst_set_t::anyarch;
273#endif
274}
275
/**
 * @brief Enables AVX512-256 if appropriate. Intended for Skylake-based Xeon D
 *        processors, wherein AVX512-256 is preferred due to higher
 *        Turbo frequencies.
 * @param flag True enables / False disables
 */
void fbgemmEnableAvx512Ymm(bool flag) {
  g_Avx512_Ymm_enabled = flag;
}
285
/**
 * @brief Determine the best available x86 machine ISA to be used for
 *        GEMM kernels.
 *        If the FBGEMM_ENABLE_INSTRUCTIONS / FBGEMM_ENABLE_AVX512_256 env.
 *        variables or fbgemmForceIsa() are set, forces the specific
 *        architecture if supported by the processor.
 *        Enforcing AVX2 on Skylake will execute the AVX2 version of the
 *        kernels. However, enforcing AVX512-256 on Broadwell will fail, and
 *        the AVX2 version of the kernels will be executed.
 */
inst_set_t fbgemmInstructionSet() {
  // Environment overrides are read exactly once per process; changing the
  // environment afterwards has no effect.
  static const inst_set_t env_forced_isa = fbgemmEnvGetIsa();
  static const bool isAvx512_Ymm_enabled = fbgemmEnvAvx512_256Enabled();

  // A programmatic override (fbgemmForceIsa) takes precedence over the env.
  inst_set_t forced_isa =
      g_forced_isa != inst_set_t::anyarch ? g_forced_isa : env_forced_isa;
  // Detect the hardware capabilities once; the immediately-invoked lambda
  // reads the statics above directly (no capture needed).
  static const inst_set_t detected_isa = ([]() {
    inst_set_t isa = inst_set_t::anyarch;
    // Check environment
    if (cpuinfo_initialize()) {
      // Prefer the 256-bit (Ymm) AVX-512 variants on Xeon D parts when
      // enabled via fbgemmEnableAvx512Ymm() or FBGEMM_ENABLE_AVX512_256.
      const bool isXeonD = fbgemmIsIntelXeonD() &&
          (g_Avx512_Ymm_enabled || isAvx512_Ymm_enabled);
      if (fbgemmHasAvx512VnniSupport()) {
        if (isXeonD) {
          isa = inst_set_t::avx512_vnni_ymm;
        } else {
          isa = inst_set_t::avx512_vnni;
        }
      } else if (fbgemmHasAvx512Support()) {
        if (isXeonD) {
          isa = inst_set_t::avx512_ymm;
        } else {
          isa = inst_set_t::avx512;
        }
      } else if (fbgemmHasAvx2Support()) {
        isa = inst_set_t::avx2;
      }
    }
    return isa;
  })();

  if (forced_isa == inst_set_t::anyarch) {
    return detected_isa;
  }
  // Honor the forced ISA only when the detected hardware can execute it;
  // otherwise fall back to the detected ISA.
  const auto supported_isa = isaSupportMap.find(detected_isa);
  assert(
      supported_isa != isaSupportMap.end() &&
      "Detected ISA can't be located in Supported ISA map");
  if (supported_isa == isaSupportMap.end()) {
    return detected_isa;
  }
  return supported_isa->second.count(forced_isa) ? forced_isa : detected_isa;
}
338
339bool isZmm(inst_set_t isa) {
340 return isa == inst_set_t::avx512 || isa == inst_set_t::avx512_vnni;
341}
342
343bool isYmm(inst_set_t isa) {
344 return isa == inst_set_t::avx512_ymm || isa == inst_set_t::avx512_vnni_ymm ||
345 isa == inst_set_t::avx2;
346}
347
348bool fbgemmIsIntelXeonD() {
349 auto const pkgInfo = cpuinfo_get_packages();
350 if (strstr(pkgInfo->name, "Intel Xeon D-") ||
351 cpuinfo_get_packages_count() == 1) {
352 return true;
353 }
354 return false;
355}
356
357bool fbgemmHasAvx512Support() {
358 return (
359 cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512bw() &&
360 cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_avx512vl());
361}
362
363bool fbgemmHasAvx2Support() {
364 return (cpuinfo_has_x86_avx2());
365}
366
367bool fbgemmHasAvx512VnniSupport() {
368 return (cpuinfo_has_x86_avx512vnni());
369}
370
371bool fbgemmHasArmNeonSupport() {
372 return (cpuinfo_has_arm_neon());
373}
374
/**
 * @brief Compute this thread's [start, end) share of a 1-D range of work.
 *        Work is split into ceil(total_work / num_threads) sized chunks, so
 *        earlier threads absorb any remainder and trailing threads may get
 *        an empty range.
 * @param thread_id this thread's id in [0, num_threads)
 * @param num_threads total number of threads; 0 yields an empty range
 * @param total_work total number of work items
 * @param start (out) first work item owned by this thread
 * @param end (out) one past the last work item owned by this thread
 */
void fbgemmPartition1D(
    int thread_id,
    int num_threads,
    int64_t total_work,
    int64_t& start,
    int64_t& end) {
  // With no threads, this thread performs no work at all.
  if (num_threads == 0) {
    start = 0;
    end = 0;
    return;
  }
  // Ceiling division: each thread owns at most `chunk` consecutive items.
  const int64_t chunk = (total_work + num_threads - 1) / num_threads;
  const int64_t lo = thread_id * chunk;
  const int64_t hi = lo + chunk;
  // Clamp so threads past the end of the range get an empty interval.
  start = std::min(lo, total_work);
  end = std::min(hi, total_work);
}
391
392void fbgemmPartition1DBlocked(
393 int thread_id,
394 int num_threads,
395 int64_t total_work,
396 int block_size,
397 int64_t& start,
398 int64_t& end) {
399 if (block_size == 1) {
400 return fbgemmPartition1D(thread_id, num_threads, total_work, start, end);
401 }
402 int64_t total_work_in_blocks = total_work / block_size;
403 int64_t start_block, end_block;
404 fbgemmPartition1D(
405 thread_id, num_threads, total_work_in_blocks, start_block, end_block);
406 start = std::min(start_block * block_size, total_work);
407 end = thread_id == num_threads - 1
408 ? std::max(end_block * block_size, total_work)
409 : std::min(end_block * block_size, total_work);
410}
411
/**
 * @brief Allocate `size` bytes aligned to `align` bytes.
 * @param align requested alignment; must satisfy the platform allocator's
 *        requirements (for posix_memalign: a power of two that is a multiple
 *        of sizeof(void*)).
 * @param size number of bytes to allocate.
 * @param raiseException when true, throw std::bad_alloc on allocation
 *        failure; when false, return nullptr on failure.
 * @return pointer to the aligned block (release with fbgemmAlignedFree), or
 *         nullptr on failure when raiseException is false.
 * @throws std::bad_alloc on failure when raiseException is true.
 */
void* fbgemmAlignedAlloc(
    size_t align,
    size_t size,
    bool raiseException /*=false*/) {
  void* aligned_mem = nullptr;
  int ret;
#ifdef _MSC_VER
  aligned_mem = _aligned_malloc(size, align);
  ret = 0;
#else
  ret = posix_memalign(&aligned_mem, align, size);
#endif
  // Bug fix: the condition used to be
  //   raiseException || ret || aligned_mem == nullptr
  // which threw std::bad_alloc on every *successful* allocation whenever
  // raiseException was true. Only throw when the allocation actually failed.
  if (raiseException && (ret || aligned_mem == nullptr)) {
    throw std::bad_alloc();
  }
  // On posix_memalign failure the output pointer is unspecified; report
  // failure as nullptr instead of returning garbage.
  return ret == 0 ? aligned_mem : nullptr;
}
430
/**
 * @brief Release a block obtained from fbgemmAlignedAlloc().
 * @param p pointer returned by fbgemmAlignedAlloc(); may be nullptr.
 */
void fbgemmAlignedFree(void* p) {
#ifdef _MSC_VER
  // Memory from _aligned_malloc must be released with _aligned_free.
  _aligned_free(p);
#else
  // Memory from posix_memalign is released with plain free().
  free(p);
#endif
}
438
/**
 * @brief Choose mb, the number of thread blocks along the m dimension, so
 *        that with nb = nthreads / mb the per-block tile shape bm x bn has a
 *        bm / bn ratio as close as possible to aspect_ratio.
 * @param m number of rows of the matrix.
 * @param n number of columns of the matrix.
 * @param nthreads number of threads to split into mb * nb blocks.
 * @param n_align alignment granularity for the per-block width bn.
 * @param aspect_ratio desired bm / bn ratio.
 * @return mb, the number of thread blocks along m (>= 1).
 */
int fbgemmGet2DPartition(
    int m,
    int n,
    int nthreads,
    int n_align,
    double aspect_ratio) {
  // mb: number of thread blocks within a socket along m.
  // nb: number of thread blocks along n.
  // mb * nb = nthreads.
  // bm: number of rows assigned per thread block (bm = ceil(m/mb)).
  // bn: number of cols assigned per thread block (bn = ceil(n/nb)).
  // find mb and nb such that bm / bn is as close as possible to aspect_ratio.

  // For large thread counts, reduce the aspect ratio when the matrix is
  // short-and-fat; this assigns more parallelism to the m-dimension.
  // Bug fix: `m / n < 0.2` used integer division, so it was true exactly
  // when m < n (the quotient truncated to 0) instead of when m/n < 0.2 as
  // the comment intends, and it divided by zero when n == 0.
  if (nthreads > 16 && n > 0 && static_cast<double>(m) / n < 0.2) {
    aspect_ratio = 0.2;
  }

  // Start from mb = 1 and greedily grow mb while the tile shape keeps
  // getting closer to the requested aspect ratio.
  int mb = 1;
  int nb = nthreads / mb;
  int bm = (m + mb - 1) / mb;
  int bn = ((n + n_align - 1) / n_align + nb - 1) / nb * n_align;
  double best_delta = std::abs(static_cast<double>(bm) / bn - aspect_ratio);
  for (int mb_candidate = 2; mb_candidate <= nthreads; mb_candidate++) {
    // mb does not need to divide nthreads: the divisibility constraint is
    // only enforced for nthreads <= 16.
    if (nthreads % mb_candidate != 0 && nthreads <= 16) {
      continue;
    }
    int nb_candidate = nthreads / mb_candidate;
    int bm_candidate = (m + mb_candidate - 1) / mb_candidate;
    int bn_candidate = ((n + n_align - 1) / n_align + nb_candidate - 1) /
        nb_candidate * n_align;
    double delta = std::abs(
        static_cast<double>(bm_candidate) / bn_candidate - aspect_ratio);
    if (delta < best_delta) {
      best_delta = delta;
      mb = mb_candidate;
    } else {
      // Greedy search: stop at the first candidate that is no better.
      break;
    }
  }
  return mb;
}
485
/**
 * @brief Partition `num_threads` threads across the g (group), m (row) and
 *        n (column) dimensions and compute this thread's coordinates in the
 *        resulting grid.
 * @param g number of groups.
 * @param m number of rows.
 * @param n number of columns.
 * @param thread_id this thread's linear id in [0, num_threads).
 * @param num_threads total number of threads launched.
 * @param n_align alignment granularity along n (forwarded to
 *        fbgemmGet2DPartition).
 * @return a thread_type_t with the per-dimension thread counts and this
 *         thread's per-dimension ids; all fields are zero when the thread is
 *         inactive (no work assigned).
 */
thread_type_t fbgemmGetThreadPartition(
    int g,
    int m,
    int n,
    int thread_id,
    int num_threads,
    int n_align) {
  assert(num_threads >= 1);

  // Fast path for the single thread case.
  if (num_threads == 1) {
    return thread_type_t{1, 1, 1, 0, 0, 0};
  }

  thread_type_t th_info;

  // Heuristic for determining the thread partitions for parallelizing across
  // the g, m or n dimensions.
  // TODO: more smart ways for thread partitions considering the
  // grain size (MR, NR) parameters
  if (g > num_threads) {
    // TODO: when G == nthreads + 1, we'll have a big load imbalance because
    // only one thread will get 2 groups.
    th_info.g_num_threads = num_threads;
  } else {
    // Use g threads along the group dimension only when they divide
    // num_threads evenly; otherwise keep the group dimension serial.
    if (g != 0 && num_threads % g == 0) {
      th_info.g_num_threads = g;
    } else {
      th_info.g_num_threads = 1;
    }
  }
  // Remaining threads are split between the m and n dimensions.
  num_threads /= th_info.g_num_threads;

  // We favor the parallelization on the m dimension compared to the n
  // dimension, so we set aspect_ratio to 0.5 here.
  th_info.m_num_threads = fbgemmGet2DPartition(m, n, num_threads, n_align, 0.5);

  // when num_threads >16, m_num_threads may not divide num_threads
  if (num_threads <= 16) {
    assert(num_threads % (th_info.m_num_threads) == 0);
  }
  th_info.n_num_threads = num_threads / th_info.m_num_threads;

  // When there are 12 threads (num_threads = 12) and g_nthreads = 2, m_nthreads
  // = 2, the threads will be organized as the following 2x2x3 layout (thread is
  // partitioned in the last-dim index (i.e., n, m, g, row-major for 2D) major
  // order):
  //
  // thread 0, thread 1, thread 2      thread 6, thread 7, thread 8
  // thread 3, thread 4, thread 5      thread 9, thread 10, thread 11
  //
  // And the corresponding (g_thread_id, m_thread_id, n_thread_id) for
  // each thread is listed as the following:
  //
  // (0, 0, 0), (0, 0, 1), (0, 0, 2)   (1, 0, 0), (1, 0, 1), (1, 0, 2)
  // (0, 1, 0), (0, 1, 1), (0, 1, 2)   (1, 1, 0), (1, 1, 1), (1, 1, 2)

  // Threads can be inactive, meaning they are launched but will not be
  // assigned any work (e.g. when the grid does not use every thread).
  if (thread_id >=
      th_info.g_num_threads * th_info.m_num_threads * th_info.n_num_threads) {
    th_info.m_thread_id = 0;
    th_info.n_thread_id = 0;
    th_info.g_thread_id = 0;
    th_info.m_num_threads = 0;
    th_info.n_num_threads = 0;
    th_info.g_num_threads = 0;
    return th_info;
  }

  // We can view the thread as the ternary with 3-dim base: {g,m,n}_num_threads.
  th_info.n_thread_id = thread_id % th_info.n_num_threads;
  thread_id /= th_info.n_num_threads;
  th_info.m_thread_id = thread_id % th_info.m_num_threads;
  thread_id /= th_info.m_num_threads;
  th_info.g_thread_id = thread_id % th_info.g_num_threads;

  return th_info;
}
565
566} // namespace fbgemm
567