1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
#define FBGEMM_EXPORTS
#include "fbgemm/Utils.h"
#include <cpuinfo.h>
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <new>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
21 | |
22 | namespace fbgemm { |
23 | |
24 | /** |
25 | * @brief Compare the reference and test result matrix to check the correctness. |
26 | * @param ref The buffer for the reference result matrix. |
27 | * @param test The buffer for the test result matrix. |
28 | * @param m The height of the reference and test result matrix. |
29 | * @param n The width of the reference and test result matrix. |
30 | * @param ld The leading dimension of the reference and test result matrix. |
31 | * @param max_mismatches_to_report The maximum number of tolerable mismatches to |
32 | * report. |
33 | * @param atol The tolerable error. |
34 | * @retval false If the number of mismatches for reference and test result |
35 | * matrix exceeds max_mismatches_to_report. |
36 | * @retval true If the number of mismatches for reference and test result matrix |
37 | * is tolerable. |
38 | */ |
39 | template <typename T> |
40 | int compare_buffers( |
41 | const T* ref, |
42 | const T* test, |
43 | int m, |
44 | int n, |
45 | int ld, |
46 | size_t max_mismatches_to_report, |
47 | float atol /*=1e-3*/) { |
48 | size_t mismatches = 0; |
49 | for (int i = 0; i < m; ++i) { |
50 | for (int j = 0; j < n; ++j) { |
51 | T reference = ref[i * ld + j], actual = test[i * ld + j]; |
52 | if (std::abs(reference - actual) > atol) { |
53 | std::cout << "\tmismatch at (" << i << ", " << j << ")" << std::endl; |
54 | if (std::is_integral<T>::value) { |
55 | std::cout << "\t reference:" << static_cast<int64_t>(reference) |
56 | << " test:" << static_cast<int64_t>(actual) << std::endl; |
57 | } else { |
58 | std::cout << "\t reference:" << reference << " test:" << actual |
59 | << std::endl; |
60 | } |
61 | |
62 | mismatches++; |
63 | if (mismatches > max_mismatches_to_report) { |
64 | return 1; |
65 | } |
66 | } |
67 | } |
68 | } |
69 | return 0; |
70 | } |
71 | |
72 | /** |
73 | * @brief Print the matrix. |
74 | * @param op Transpose type of the matrix. |
75 | * @param R The height of the matrix. |
76 | * @param C The width of the matrix. |
77 | * @param ld The leading dimension of the matrix. |
78 | * @param name The prefix string before printing the matrix. |
79 | */ |
80 | template <typename T> |
81 | void printMatrix( |
82 | matrix_op_t op, |
83 | const T* inp, |
84 | size_t R, |
85 | size_t C, |
86 | size_t ld, |
87 | std::string name) { |
88 | // R: number of rows in op(inp) |
89 | // C: number of cols in op(inp) |
90 | // ld: leading dimension in inp |
91 | std::cout << name << ":" |
92 | << "[" << R << ", " << C << "]" << std::endl; |
93 | bool tr = (op == matrix_op_t::Transpose); |
94 | for (size_t r = 0; r < R; ++r) { |
95 | for (size_t c = 0; c < C; ++c) { |
96 | T res = tr ? inp[c * ld + r] : inp[r * ld + c]; |
97 | if (std::is_integral<T>::value) { |
98 | std::cout << std::setw(5) << static_cast<int64_t>(res) << " " ; |
99 | } else { |
100 | std::cout << std::setw(5) << res << " " ; |
101 | } |
102 | } |
103 | std::cout << std::endl; |
104 | } |
105 | } |
106 | |
// Explicit instantiations of compare_buffers for the element types FBGEMM
// exports; this keeps the template definition out of the public header.
template int compare_buffers<float>(
    const float* ref,
    const float* test,
    int m,
    int n,
    int ld,
    size_t max_mismatches_to_report,
    float atol);

template int compare_buffers<int32_t>(
    const int32_t* ref,
    const int32_t* test,
    int m,
    int n,
    int ld,
    size_t max_mismatches_to_report,
    float atol);

template int compare_buffers<uint8_t>(
    const uint8_t* ref,
    const uint8_t* test,
    int m,
    int n,
    int ld,
    size_t max_mismatches_to_report,
    float atol);

template int compare_buffers<int64_t>(
    const int64_t* ref,
    const int64_t* test,
    int m,
    int n,
    int ld,
    size_t max_mismatches_to_report,
    float atol);
142 | |
// Explicit instantiations of printMatrix for the element types FBGEMM
// exports; this keeps the template definition out of the public header.
template void printMatrix<float>(
    matrix_op_t op,
    const float* inp,
    size_t R,
    size_t C,
    size_t ld,
    std::string name);
template void printMatrix<int8_t>(
    matrix_op_t op,
    const int8_t* inp,
    size_t R,
    size_t C,
    size_t ld,
    std::string name);
template void printMatrix<uint8_t>(
    matrix_op_t op,
    const uint8_t* inp,
    size_t R,
    size_t C,
    size_t ld,
    std::string name);
template void printMatrix<int32_t>(
    matrix_op_t op,
    const int32_t* inp,
    size_t R,
    size_t C,
    size_t ld,
    std::string name);
171 | |
172 | namespace { |
173 | inst_set_t g_forced_isa = inst_set_t::anyarch; |
174 | bool g_Avx512_Ymm_enabled = false; |
175 | |
176 | inst_set_t fbgemmEnvGetIsa() { |
177 | static const char* isa_env = "FBGEMM_ENABLE_INSTRUCTIONS" ; |
178 | static const std::unordered_map<std::string, inst_set_t> isaMap = { |
179 | {"AVX2" , inst_set_t::avx2}, |
180 | {"AVX512" , inst_set_t::avx512}, |
181 | {"AVX512_E1" , inst_set_t::avx512_vnni}, |
182 | {"AVX512_256" , inst_set_t::avx512_ymm}, |
183 | {"AVX512_E1_256" , inst_set_t::avx512_vnni_ymm}, |
184 | }; |
185 | const char* env = std::getenv(isa_env); |
186 | if (env == nullptr) { |
187 | return inst_set_t::anyarch; |
188 | } |
189 | |
190 | #ifdef __aarch64__ |
191 | #ifdef VLOG |
192 | VLOG(0) << "[" << env << "] not supported on aarch64" ; |
193 | #endif |
194 | return inst_set_t::anyarch; |
195 | #endif |
196 | |
197 | std::string val(env); |
198 | std::transform(val.begin(), val.end(), val.begin(), ::toupper); |
199 | auto it = isaMap.find(val); |
200 | return it == isaMap.end() ? inst_set_t::anyarch : it->second; |
201 | } |
202 | |
203 | bool fbgemmEnvAvx512_256Enabled() { |
204 | static const char* isa_env = "FBGEMM_ENABLE_AVX512_256" ; |
205 | const char* env = std::getenv(isa_env); |
206 | if (env == nullptr) { |
207 | return false; |
208 | } |
209 | |
210 | #ifdef __aarch64__ |
211 | #ifdef VLOG |
212 | VLOG(0) << "[" << env << "] not supported on aarch64" ; |
213 | #endif |
214 | return false; |
215 | #endif |
216 | |
217 | std::string val(env); |
218 | std::transform(val.begin(), val.end(), val.begin(), ::tolower); |
219 | return val == "true" || val == "1" ; |
220 | } |
221 | |
222 | // This is require for build by older compilers GCC 5.4 and C++11 |
223 | struct inst_set_t_hash { |
224 | std::size_t operator()(inst_set_t t) const { |
225 | return static_cast<std::size_t>(t); |
226 | } |
227 | }; |
228 | |
229 | std::unordered_map< |
230 | inst_set_t, |
231 | std::unordered_set<inst_set_t, inst_set_t_hash>, |
232 | inst_set_t_hash> |
233 | isaSupportMap = { |
234 | {inst_set_t::anyarch, {inst_set_t::anyarch}}, |
235 | {inst_set_t::avx2, {inst_set_t::avx2, inst_set_t::anyarch}}, |
236 | {inst_set_t::avx512, |
237 | {inst_set_t::avx512, inst_set_t::avx512_ymm, inst_set_t::avx2}}, |
238 | {inst_set_t::avx512_ymm, |
239 | {inst_set_t::avx512, inst_set_t::avx512_ymm, inst_set_t::avx2}}, |
240 | {inst_set_t::avx512_vnni, |
241 | {inst_set_t::avx512_vnni, |
242 | inst_set_t::avx512_vnni_ymm, |
243 | inst_set_t::avx512, |
244 | inst_set_t::avx512_ymm, |
245 | inst_set_t::avx2}}, |
246 | {inst_set_t::avx512_vnni_ymm, |
247 | {inst_set_t::avx512_vnni, |
248 | inst_set_t::avx512_vnni_ymm, |
249 | inst_set_t::avx512, |
250 | inst_set_t::avx512_ymm, |
251 | inst_set_t::avx2}}, |
252 | }; |
253 | |
254 | } // namespace |
255 | |
256 | /** |
257 | * @brief Force specific architecure to for GEMM kernel execution |
258 | * overides FBGEMM_ENABLE_AVX512_256 env. variable |
259 | * @param isa the ISA to enforce, supported optionsi |
260 | * AVX2 inst_set_t::avx2 |
261 | * AVX512 inst_set_t::avx512 |
262 | * AVX512_E1 inst_set_t::avx512_vnni |
263 | * AVX512_256 inst_set_t::avx512_ymm |
264 | * AVX512_E1_256 inst_set_t::avx512_vnni_ymm |
265 | */ |
266 | void fbgemmForceIsa(inst_set_t isa) { |
267 | g_forced_isa = isa; |
268 | #ifdef __aarch64__ |
269 | #ifdef VLOG |
270 | VLOG(0) << "[anyarch] forced on aarch64" ; |
271 | #endif |
272 | g_forced_isa = inst_set_t::anyarch; |
273 | #endif |
274 | } |
275 | |
276 | /** |
277 | * @brief Enables AVX512-256 if appriate. Inteded for Skylake based Xeon-D |
278 | * processors, wherein AXV512-256 is preferred due to higher |
279 | * Turbo frequencis |
280 | * @param flag True enables / False disables |
281 | */ |
282 | void fbgemmEnableAvx512Ymm(bool flag) { |
283 | g_Avx512_Ymm_enabled = flag; |
284 | } |
285 | |
286 | /** |
287 | * @brief Determine the best available x86 machine ISA to be used for |
288 | * GEMM kernels. |
289 | * FBGEMM_ENABLE_AVX512_256 env. or fbgemmForceIsa() are set |
290 | * forces to specific architecture if supported by the processor. |
291 | * Enforcing on Skylake to AVX2 will execute AVX2 version of the kernel |
292 | * However, enforcing AVX512-256 on Broadwell will fail, and AVX2 version |
293 | * of the kernels will be executed. |
294 | */ |
295 | inst_set_t fbgemmInstructionSet() { |
296 | static const inst_set_t env_forced_isa = fbgemmEnvGetIsa(); |
297 | static const bool isAvx512_Ymm_enabled = fbgemmEnvAvx512_256Enabled(); |
298 | |
299 | inst_set_t forced_isa = |
300 | g_forced_isa != inst_set_t::anyarch ? g_forced_isa : env_forced_isa; |
301 | static const inst_set_t detected_isa = ([]() { |
302 | inst_set_t isa = inst_set_t::anyarch; |
303 | // Check environment |
304 | if (cpuinfo_initialize()) { |
305 | const bool isXeonD = fbgemmIsIntelXeonD() && |
306 | (g_Avx512_Ymm_enabled || isAvx512_Ymm_enabled); |
307 | if (fbgemmHasAvx512VnniSupport()) { |
308 | if (isXeonD) { |
309 | isa = inst_set_t::avx512_vnni_ymm; |
310 | } else { |
311 | isa = inst_set_t::avx512_vnni; |
312 | } |
313 | } else if (fbgemmHasAvx512Support()) { |
314 | if (isXeonD) { |
315 | isa = inst_set_t::avx512_ymm; |
316 | } else { |
317 | isa = inst_set_t::avx512; |
318 | } |
319 | } else if (fbgemmHasAvx2Support()) { |
320 | isa = inst_set_t::avx2; |
321 | } |
322 | } |
323 | return isa; |
324 | })(); |
325 | |
326 | if (forced_isa == inst_set_t::anyarch) { |
327 | return detected_isa; |
328 | } |
329 | const auto supported_isa = isaSupportMap.find(detected_isa); |
330 | assert( |
331 | supported_isa != isaSupportMap.end() && |
332 | "Detected ISA can't be located in Supported ISA map" ); |
333 | if (supported_isa == isaSupportMap.end()) { |
334 | return detected_isa; |
335 | } |
336 | return supported_isa->second.count(forced_isa) ? forced_isa : detected_isa; |
337 | } |
338 | |
339 | bool isZmm(inst_set_t isa) { |
340 | return isa == inst_set_t::avx512 || isa == inst_set_t::avx512_vnni; |
341 | } |
342 | |
343 | bool isYmm(inst_set_t isa) { |
344 | return isa == inst_set_t::avx512_ymm || isa == inst_set_t::avx512_vnni_ymm || |
345 | isa == inst_set_t::avx2; |
346 | } |
347 | |
348 | bool fbgemmIsIntelXeonD() { |
349 | auto const pkgInfo = cpuinfo_get_packages(); |
350 | if (strstr(pkgInfo->name, "Intel Xeon D-" ) || |
351 | cpuinfo_get_packages_count() == 1) { |
352 | return true; |
353 | } |
354 | return false; |
355 | } |
356 | |
357 | bool fbgemmHasAvx512Support() { |
358 | return ( |
359 | cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512bw() && |
360 | cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_avx512vl()); |
361 | } |
362 | |
363 | bool fbgemmHasAvx2Support() { |
364 | return (cpuinfo_has_x86_avx2()); |
365 | } |
366 | |
367 | bool fbgemmHasAvx512VnniSupport() { |
368 | return (cpuinfo_has_x86_avx512vnni()); |
369 | } |
370 | |
371 | bool fbgemmHasArmNeonSupport() { |
372 | return (cpuinfo_has_arm_neon()); |
373 | } |
374 | |
// Splits total_work items evenly across num_threads and returns, via
// [start, end), the half-open range owned by thread_id. Threads beyond the
// end of the work receive an empty range; num_threads == 0 yields [0, 0).
void fbgemmPartition1D(
    int thread_id,
    int num_threads,
    int64_t total_work,
    int64_t& start,
    int64_t& end) {
  if (num_threads == 0) {
    // No workers: this thread performs no work.
    start = 0;
    end = 0;
    return;
  }
  const int64_t chunk = (total_work + num_threads - 1) / num_threads;
  const int64_t tid = thread_id;
  start = std::min(tid * chunk, total_work);
  end = std::min((tid + 1) * chunk, total_work);
}
391 | |
392 | void fbgemmPartition1DBlocked( |
393 | int thread_id, |
394 | int num_threads, |
395 | int64_t total_work, |
396 | int block_size, |
397 | int64_t& start, |
398 | int64_t& end) { |
399 | if (block_size == 1) { |
400 | return fbgemmPartition1D(thread_id, num_threads, total_work, start, end); |
401 | } |
402 | int64_t total_work_in_blocks = total_work / block_size; |
403 | int64_t start_block, end_block; |
404 | fbgemmPartition1D( |
405 | thread_id, num_threads, total_work_in_blocks, start_block, end_block); |
406 | start = std::min(start_block * block_size, total_work); |
407 | end = thread_id == num_threads - 1 |
408 | ? std::max(end_block * block_size, total_work) |
409 | : std::min(end_block * block_size, total_work); |
410 | } |
411 | |
/**
 * @brief Allocate size bytes aligned to align.
 * @param align Required alignment (power of two, multiple of sizeof(void*)
 *        for posix_memalign).
 * @param size Number of bytes to allocate.
 * @param raiseException When true, throw std::bad_alloc on allocation
 *        failure; when false, return nullptr instead.
 * @return Pointer to the aligned allocation, or nullptr on failure when
 *         raiseException is false. Release with fbgemmAlignedFree().
 */
void* fbgemmAlignedAlloc(
    size_t align,
    size_t size,
    bool raiseException /*=false*/) {
  void* aligned_mem = nullptr;
  int ret;
#ifdef _MSC_VER
  aligned_mem = _aligned_malloc(size, align);
  ret = 0;
#else
  ret = posix_memalign(&aligned_mem, align, size);
#endif
  // Throw std::bad_alloc ONLY in the case of memory allocation failure.
  // (The previous condition `raiseException || ret || aligned_mem == nullptr`
  // threw on every successful allocation whenever raiseException was true.)
  if (ret || aligned_mem == nullptr) {
    if (raiseException) {
      throw std::bad_alloc();
    }
    return nullptr;
  }
  return aligned_mem;
}
430 | |
// Releases memory obtained from fbgemmAlignedAlloc. Must mirror the allocator
// used there: _aligned_malloc on MSVC requires _aligned_free, while
// posix_memalign memory is released with free().
void fbgemmAlignedFree(void* p) {
#ifdef _MSC_VER
  _aligned_free(p);
#else
  free(p);
#endif
}
438 | |
/**
 * @brief Pick the number of thread blocks along the m dimension (mb) for a
 *        2D partition of an m x n problem across nthreads threads.
 * @param m Number of rows of the problem.
 * @param n Number of columns of the problem.
 * @param nthreads Total number of threads to partition across.
 * @param n_align Alignment (granularity) of the n-dimension blocks.
 * @param aspect_ratio Desired ratio bm/bn of per-block rows to columns.
 * @return The chosen mb; nb is then nthreads / mb.
 */
int fbgemmGet2DPartition(
    int m,
    int n,
    int nthreads,
    int n_align,
    double aspect_ratio) {
  // mb: number of thread blocks within a socket along m.
  // nb: number of thread blocks along n.
  // mb * nb = nthreads.
  // bm: number of rows assigned per thread block (bm = ceil(m/mb)).
  // bn: number of cols assigned per thread block (bn = ceil(n/nb)).
  // find mb and nb such that bm / bn is as close as possible to aspect_ratio.

  // for large thread numbers, we would like to reduce the aspect_ratio ---
  // if the matrix is short-and-fat
  // this allows us to assign more parallelism to i-dimension
  //
  // The ratio must be computed in floating point: the previous integer
  // division (m / n) was 0 for every m < n, so the heuristic fired for any
  // matrix merely narrower than tall (e.g. m/n = 0.9), not just short-and-fat
  // ones.
  if (nthreads > 16 && static_cast<double>(m) / n < 0.2) {
    aspect_ratio = 0.2;
  }

  int mb = 1;
  int nb = nthreads / mb;
  int bm = (m + mb - 1) / mb;
  int bn = ((n + n_align - 1) / n_align + nb - 1) / nb * n_align;
  double best_delta = std::abs(static_cast<double>(bm) / bn - aspect_ratio);
  for (int mb_candidate = 2; mb_candidate <= nthreads; mb_candidate++) {
    // mb does not need to divide nthreads
    // here nthreads % mb_candidate!=0 constraint is removed for nthreads>16
    if (nthreads % mb_candidate != 0 && nthreads <= 16) {
      continue;
    }
    int nb_candidate = nthreads / mb_candidate;
    int bm_candidate = (m + mb_candidate - 1) / mb_candidate;
    int bn_candidate = ((n + n_align - 1) / n_align + nb_candidate - 1) /
        nb_candidate * n_align;
    double delta = std::abs(
        static_cast<double>(bm_candidate) / bn_candidate - aspect_ratio);
    // Greedy search: keep growing mb while the aspect-ratio error improves,
    // and stop at the first candidate that is no better.
    if (delta < best_delta) {
      best_delta = delta;
      mb = mb_candidate;
    } else {
      break;
    }
  }
  return mb;
}
485 | |
// Computes the 3D (g, m, n) thread-grid coordinates and extents for
// thread_id out of num_threads, partitioning a problem with g groups and an
// m x n matrix per group. Returns a thread_type_t holding
// {g,m,n}_num_threads and {g,m,n}_thread_id; inactive threads get all-zero
// extents so they perform no work.
thread_type_t fbgemmGetThreadPartition(
    int g,
    int m,
    int n,
    int thread_id,
    int num_threads,
    int n_align) {
  assert(num_threads >= 1);

  // Fast path for the single thread case.
  if (num_threads == 1) {
    return thread_type_t{1, 1, 1, 0, 0, 0};
  }

  thread_type_t th_info;

  // Heuristic for determine the thread partitions for parallelizing across g, m
  // or n dimensions.
  // TODO: more smart ways for thread partitions considering the
  // grain size (MR, NR) parameters
  if (g > num_threads) {
    // TODO: when G == nthreads + 1, we'll have a big load imbalance because
    // only one thread will get 2 groups.
    th_info.g_num_threads = num_threads;
  } else {
    // Parallelize fully across groups only when the thread count divides
    // evenly; otherwise keep all threads for the m/n dimensions.
    if (g != 0 && num_threads % g == 0) {
      th_info.g_num_threads = g;
    } else {
      th_info.g_num_threads = 1;
    }
  }
  // Remaining threads (per group) are split between the m and n dimensions.
  num_threads /= th_info.g_num_threads;

  // We favor the parallelization on the m dimension compared to the n
  // dimension, so we set aspect_ratio to 0.5 here.
  th_info.m_num_threads = fbgemmGet2DPartition(m, n, num_threads, n_align, 0.5);

  // when num_threads >16, m_num_threads may not divide num_threads
  if (num_threads <= 16) {
    assert(num_threads % (th_info.m_num_threads) == 0);
  }
  th_info.n_num_threads = num_threads / th_info.m_num_threads;

  // When there are 12 threads (num_threads = 12) and g_nthreads = 2, m_nthreads
  // = 2, the threads will be organized as the following 2x2x3 layout (thread is
  // partitioned in the last-dim index (i.e., n, m, g, row-major for 2D) major
  // order):
  //
  // thread 0, thread 1, thread 2      thread 6, thread 7, thread 8
  // thread 3, thread 4, thread 5      thread 9, thread 10, thread 11
  //
  // And the corresponding (g_thread_id, m_thread_id, n_thread_id) for
  // each thread is listed as the following:
  //
  // (0, 0, 0), (0, 0, 1), (0, 0, 2)   (1, 0, 0), (1, 0, 1), (1, 0, 2)
  // (0, 1, 0), (0, 1, 1), (0, 1, 2)   (1, 1, 0), (1, 1, 1), (1, 1, 2)

  // thread can be inactive,
  // meaning they are launched, but will not be assigned any work
  if (thread_id >=
      th_info.g_num_threads * th_info.m_num_threads * th_info.n_num_threads) {
    th_info.m_thread_id = 0;
    th_info.n_thread_id = 0;
    th_info.g_thread_id = 0;
    th_info.m_num_threads = 0;
    th_info.n_num_threads = 0;
    th_info.g_num_threads = 0;
    return th_info;
  }

  // We can view the thread as the ternary with 3-dim base: {g,m,n}_num_threads.
  th_info.n_thread_id = thread_id % th_info.n_num_threads;
  thread_id /= th_info.n_num_threads;
  th_info.m_thread_id = thread_id % th_info.m_num_threads;
  thread_id /= th_info.m_num_threads;
  th_info.g_thread_id = thread_id % th_info.g_num_threads;

  return th_info;
}
565 | |
566 | } // namespace fbgemm |
567 | |