1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_BRGEMM_BRGEMM_TYPES_HPP |
18 | #define CPU_X64_BRGEMM_BRGEMM_TYPES_HPP |
19 | |
20 | #include "common/primitive_attr.hpp" |
21 | #include "cpu/platform.hpp" |
22 | #include "cpu/x64/cpu_isa_traits.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace cpu { |
27 | namespace x64 { |
28 | |
29 | // The type defines organization of batch of matrices |
// Organization of the batch of matrices passed to a brgemm kernel.
enum brgemm_batch_kind_t {
    // A and B are supplied as arrays of pointers.
    brgemm_addr = 1,
    // A base address plus an array of offsets from that base address.
    brgemm_offs = 2,
    // A base address plus a fixed stride between consecutive matrices.
    brgemm_strd = 3,
};
38 | |
39 | // The type defines the storage format of matrix |
// Storage order of a matrix.
enum brgemm_layout_t {
    brgemm_layout_undef = 0,
    brgemm_col_major = 1,
    brgemm_row_major = 2,
};
45 | |
// Broadcast granularity for per-tensor values (used for the zero-point
// kinds stored in brgemm_t::zp_type_{a,b,c}).
enum brgemm_broadcast_t {
    none = 0, // no value applied
    per_tensor = 1, // a single value for the whole tensor
    per_m = 2, // one value per row of the M dimension
    per_n = 3, // one value per column of the N dimension
    per_k = 4, // one value per K element
};
53 | |
// Fixed strides between consecutive matrices of a batch; relevant for the
// brgemm_strd batch kind ("base address and fixed stride", see above).
struct brgemm_strides_t {
    // Stride between A matrices
    dim_t stride_a;
    // Stride between B matrices
    dim_t stride_b;
};
60 | |
// Loop-order hint for the kernel's blocking loops.
enum brgemm_kernel_loop_order_t {
    brgemm_lo_default = 0,
    brgemm_lo_bl_1load = 1,
    brgemm_lo_bl_1bcst = 2,
};
66 | |
// Prefetching-mode hint for the generated kernel.
enum brgemm_kernel_prefetching_t {
    brgemm_prf_default = 1,
    brgemm_prf1 = 2,
    brgemm_prf2 = 3,
};
72 | |
// Which blocking dimension (bd = M, ld = N) forms the innermost loop.
enum brgemm_kernel_innermost_loop_t {
    brgemm_bd_loop_innermost = 0,
    brgemm_ld_loop_innermost = 1,
};
77 | |
// Tri-state hint for using non-temporal loads (undef = let the
// implementation decide).
enum brgemm_kernel_hint_nt_t {
    brgemm_hint_nt_undef = -1,
    brgemm_hint_nt_false = 0,
    brgemm_hint_nt_true = 1,
};
83 | |
// Pair of prefetch-distance hints; -1 marks a distance that was not set
// (presumably "use the implementation default" -- confirm in generators).
struct brgemm_prf_t {
    int dist1 {-1};
    int dist2 {-1};
};
88 | |
// Attributes that tune brgemm kernel generation; applied to a brgemm
// descriptor before the kernel is created (e.g. via set_attr calls).
struct DNNL_API brgemm_attr_t {
    brgemm_attr_t();
    // if unrolled kernel is used (use_uker == true)
    // then "max_bs" is the only batch size that can be used on kernel call
    // else "max_bs" is the maximum batch size that can be used
    int max_bs;
    // Maximum virtual padding (rows) at the top/bottom of the A matrix.
    int max_top_vpad, max_bottom_vpad;
    // Expected sizes of the A/B/C buffers; hints only (units presumably
    // bytes -- TODO confirm against the kernel generators).
    dim_t hint_expected_A_size, hint_expected_B_size, hint_expected_C_size;
    // Which blocking loop should be innermost in the generated kernel.
    brgemm_kernel_innermost_loop_t hint_innermost_loop
            = brgemm_ld_loop_innermost;
    brgemm_kernel_loop_order_t hint_loop_order;
    brgemm_kernel_prefetching_t hint_prefetching
            = brgemm_kernel_prefetching_t::brgemm_prf_default;
    // Prefetch-distance hints for the A/B/C tensors.
    brgemm_prf_t hint_prfA, hint_prfB, hint_prfC;

    // NOTE(review): inferred from the name -- presumably makes the kernel
    // avoid reading past the tensor tail; confirm in the generators.
    bool wary_tail_read;
    bool generate_skip_accumulation;
    // bd_mask is char array in which each element is a boolean value that
    // determines whether to write this row to the result matrix or skip
    char *bd_mask;
    // Value of bd_mask_level specifies how bd_mask is used in brgemm kernel
    // 0 - bd_mask is not used
    // 1 - bd_mask is used on storing stage only
    // 2 - bd_mask used both on reading and storing stages
    int bd_mask_level;
    // use_uker is a boolean value that determines whether to use the unrolled
    // kernel or not
    bool use_uker;
    // use_interleave_stores is a value that determines whether to use the
    // interleave stores or not
    bool use_interleave_stores;
    impl::fpmath_mode_t fpmath_mode = fpmath_mode::strict;
    // Second level leading dimension describing distance between 16-line
    // blocks in case of blocked layout. Used to calculate address of next
    // bd block. By default are equal to regular leading dimension parameters
    // specified on brgemm creation.
    // Supported by brgemm unrolled kernel for now.
    int LDA2 {0}, LDB2 {0}, LDC2_M {0}, LDC2_N {0};
    // If "true" then batchsize is allowed to change on each kernel call
    // and there is no unrolling by batchsize in kernel
    bool var_bs {false};
    // If true, only the post-ops part is executed (inferred from the name
    // and brgemm_kernel_params_t::do_post_ops -- confirm).
    bool postops_only {false};

    // Explicit blocking-size hints; 0 means "let the implementation choose".
    int hint_bd_block {0};
    int hint_ld_block {0};
    int hint_bd_block2 {0};
    int hint_ld_block2 {0};

    // Non-temporal-load hints for the A and B tensors.
    brgemm_kernel_hint_nt_t hint_load_nt_A {brgemm_hint_nt_undef};
    brgemm_kernel_hint_nt_t hint_load_nt_B {brgemm_hint_nt_undef};
};
140 | |
141 | struct brgemm_batch_element_t { |
142 | brgemm_batch_element_t() { |
143 | ptr.A = ptr.B = nullptr; |
144 | vvpad.top = vvpad.bottom = 0; |
145 | } |
146 | union { |
147 | struct { |
148 | const void *A; |
149 | const void *B; |
150 | } ptr; |
151 | struct { |
152 | dim_t A; |
153 | dim_t B; |
154 | } offset; |
155 | }; |
156 | union { |
157 | struct { |
158 | dim_t top; |
159 | dim_t bottom; |
160 | } vvpad; |
161 | struct { |
162 | dim_t left; |
163 | dim_t right; |
164 | } hvpad; |
165 | }; |
166 | }; |
167 | |
168 | struct brgemm_t { |
169 | int bcast_dim = 0; // M; |
170 | int load_dim = 0; // N; |
171 | int reduce_dim = 0; // K; |
172 | int LDA = 0; |
173 | int LDB = 0; |
174 | int LDC = 0; |
175 | int LDD = 0; |
176 | // we use two isa_ variables |
177 | // isa_user to store the user provided isa value |
178 | // isa_impl to store actual implementation. This can change until the kernel |
179 | // is created, as calls to set_attr can affect this variable. Ex: bf32 |
180 | impl::cpu::x64::cpu_isa_t isa_user = isa_undef; |
181 | impl::cpu::x64::cpu_isa_t isa_impl = isa_undef; |
182 | |
183 | int LDA2 {0}, LDB2 {0}, LDC2_M {0}, LDC2_N {0}; |
184 | bool is_blocked = false; |
185 | |
186 | float alpha = 0.0f; |
187 | float beta = 0.0f; |
188 | |
189 | int bdb = 0, bd_block = 0, bdb_tail = 0; |
190 | int bdb2 = 0, bd_block2 = 0, bdb2_tail = 0; |
191 | int ldb = 0, ld_block = 0, ldb_tail = 0; |
192 | int ldb2 = 0, ld_block2 = 0, ldb2_tail = 0; |
193 | int rdb = 0, rd_block = 0, rdb_tail = 0; |
194 | int rd_step = 0, ld_step = 0; |
195 | |
196 | impl::data_type_t dt_a = data_type::undef; |
197 | impl::data_type_t dt_c = data_type::undef; |
198 | impl::data_type_t dt_b = data_type::undef; |
199 | impl::data_type_t dt_d = data_type::undef; |
200 | impl::data_type_t dt_bias = data_type::undef; |
201 | |
202 | int typesize_A = 0; |
203 | int typesize_B = 0; |
204 | int typesize_C = 0; |
205 | int typesize_D = 0; |
206 | int typesize_bias = 0; |
207 | |
208 | bool is_ymm = false; |
209 | bool is_zmm = false; |
210 | bool is_tmm = false; |
211 | bool is_int8 = false, is_int8_tmm = false; |
212 | bool is_bf16 = false, is_bf16_tmm = false, is_bf16_emu = false; |
213 | bool is_f16 = false, is_f16_tmm = false; |
214 | bool is_f32 = false; |
215 | bool is_bf32 = false; |
216 | |
217 | dim_t stride_a = 0; // Offset in bytes |
218 | dim_t stride_b = 0; |
219 | |
220 | brgemm_layout_t layout = brgemm_layout_undef; |
221 | brgemm_batch_kind_t type; |
222 | |
223 | bool load_nt_A = false; |
224 | bool load_nt_B = false; |
225 | bool embd_bcst = false; |
226 | bool is_dgmm = false; // set to true in brdgmm_desc_init |
227 | bool with_bias = false; |
228 | bool with_sum = false; |
229 | float sum_scale = 0.0f; |
230 | int32_t sum_zp = 0; |
231 | impl::data_type_t sum_dt; |
232 | bool with_eltwise = false; |
233 | bool with_binary = false; |
234 | bool with_scales = false; |
235 | bool req_cal_comp_pads = false; |
236 | bool req_s8s8_compensation = false; |
237 | brgemm_broadcast_t zp_type_a = brgemm_broadcast_t::none; |
238 | brgemm_broadcast_t zp_type_b = brgemm_broadcast_t::none; |
239 | brgemm_broadcast_t zp_type_c = brgemm_broadcast_t::none; |
240 | |
241 | int is_oc_scale = 0; |
242 | |
243 | const primitive_attr_t *attr = nullptr; |
244 | const memory_desc_t *dst_md = nullptr; |
245 | |
246 | brgemm_attr_t brgattr; |
247 | static constexpr int MAX_VPAD = 100; |
248 | static constexpr int AMX_TILES_NUM = 8; |
249 | |
250 | int is_M_tail; |
251 | |
252 | bool interleave_tilestores_ = false; |
253 | |
254 | brgemm_prf_t prfA, prfB, prfC; |
255 | |
256 | bool is_row_major() const { |
257 | assert(layout != brgemm_layout_undef); |
258 | return layout == brgemm_row_major; |
259 | } |
260 | |
261 | // Tile register decomposition |
262 | int get_bd_block2() const noexcept { |
263 | return (bdb_tail) ? bd_block2 + 1 : bd_block2; |
264 | } |
265 | int get_ld_block2() const noexcept { |
266 | return (ldb_tail) ? ld_block2 + 1 : ld_block2; |
267 | } |
268 | int get_num_C_tiles() const noexcept { |
269 | return get_bd_block2() * get_ld_block2(); |
270 | } |
271 | int get_C_tensor(int m, int n, bool m_tail = false, |
272 | bool n_tail = false) const noexcept { |
273 | auto M = m_tail ? get_bd_block2() - 1 : m; |
274 | auto N = n_tail ? get_ld_block2() - 1 : n; |
275 | return (M * get_ld_block2() + N); |
276 | } |
277 | |
278 | int tiles_for_A() const noexcept { |
279 | return (AMX_TILES_NUM - get_num_C_tiles() - 1); |
280 | } |
281 | |
282 | int get_A_tensor(int m, bool m_tail = false) const noexcept { |
283 | auto full_A_tiles = get_num_A_tiles() - (bdb_tail ? 1 : 0); |
284 | auto M = m_tail ? get_num_A_tiles() - 1 : m % full_A_tiles; |
285 | return (get_num_C_tiles() + M); |
286 | } |
287 | |
288 | int get_num_A_tiles() const noexcept { |
289 | return nstl::min(get_bd_block2(), tiles_for_A()); |
290 | } |
291 | |
292 | int tiles_for_B() const noexcept { |
293 | return (AMX_TILES_NUM - get_num_C_tiles() - get_num_A_tiles()); |
294 | } |
295 | |
296 | int get_B_tensor(int n, bool n_tail = false) const noexcept { |
297 | auto full_B_tiles = get_num_B_tiles() - (ldb_tail ? 1 : 0); |
298 | auto N = n_tail ? get_num_B_tiles() - 1 : n % full_B_tiles; |
299 | return (get_num_C_tiles() + get_num_A_tiles() + N); |
300 | } |
301 | |
302 | int get_num_B_tiles() const noexcept { |
303 | return nstl::min(get_ld_block2(), tiles_for_B()); |
304 | } |
305 | |
306 | int get_wsp_buffer_size() const noexcept { |
307 | int sz = 0; |
308 | if (is_tmm) { |
309 | constexpr int tilesize = 1024; |
310 | sz = get_num_C_tiles() * tilesize; // postops buffer |
311 | if (is_bf32) { |
312 | const int n_bdb = bd_block2; |
313 | const int n_rdb = rdb + (rdb_tail != 0); |
314 | const int n_ldb = ldb + (ldb_tail != 0); |
315 | const int downcvt_tiles |
316 | = brgattr.max_bs * n_rdb * (n_bdb + n_ldb); |
317 | sz += downcvt_tiles * tilesize; |
318 | } |
319 | } |
320 | return sz; |
321 | } |
322 | |
323 | bool is_b_data_layout_vnni() { |
324 | // True in general, only exception is f16 with avx512_core_fp16. |
325 | // We also return `true` for bf32 (brgattr.fpmath_mode_ = bf16), |
326 | // because the data transformation to vnni layout is internal |
327 | // and transparent to user. |
328 | return !(dt_b == data_type::f16 && isa_impl == avx512_core_fp16); |
329 | } |
330 | }; |
331 | |
// Run-time arguments passed to a generated brgemm kernel on every call.
struct brgemm_kernel_params_t {
    const void *ptr_A;
    const void *ptr_B;
    // Per-matrix batch descriptors (addresses/offsets plus paddings).
    const brgemm_batch_element_t *batch;
    void *ptr_C;

    const void *ptr_bias;
    void *ptr_D;

    /* kernel takes single pointer scales, but configuration relies on a
     * combination of arg scales. This helps to reuse attributes from
     * primitives, but requires them to pre-compute
     * scales = src_scale * wei_scale[:]
     */
    const void *ptr_scales;
    // Workspace buffer (see brgemm_t::get_wsp_buffer_size()).
    void *ptr_buf;

    // Flags passed as size_t (presumably read as booleans by the JIT
    // code -- confirm in the generators).
    size_t do_post_ops;
    size_t do_apply_comp;
    // Batch size for this call.
    size_t BS;

    /*
     * ptr to table of void * elements that are pointers to post_op binary
     * src1 tensors
     */
    const void *post_ops_binary_rhs_arg_vec;
    size_t oc_logical_off;
    size_t first_mb_matrix_addr_off;
    size_t dst_row_logical_off;

    char *data_C_ptr_;

    // Zero-point compensation inputs; see brgemm_post_ops_data_t docs.
    const void *a_zp_compensations = nullptr;
    const void *b_zp_compensations = nullptr;
    const void *c_zp_values = nullptr;
    size_t skip_accm = 0;
    int32_t zp_a_val = 1;
};
370 | |
// Forward declarations of the JIT code generators; the definitions live in
// the corresponding implementation files.
template <cpu_isa_t isa, typename Vmm>
struct jit_brgemm_kernel_t;
struct jit_brgemm_amx_uker_base_t;
template <cpu_isa_t isa, typename Vmm>
struct jit_brdgmm_kernel_base_t;
376 | |
377 | struct brgemm_kernel_t { |
378 | brgemm_kernel_t() {}; |
379 | virtual ~brgemm_kernel_t() {}; |
380 | virtual status_t create_kernel() = 0; |
381 | virtual void operator()(brgemm_kernel_params_t *) const = 0; |
382 | }; |
383 | |
384 | template <cpu_isa_t isa, typename Vmm> |
385 | struct brgemm_kernel_common_t : public brgemm_kernel_t { |
386 | brgemm_kernel_common_t(const brgemm_t abrd); |
387 | ~brgemm_kernel_common_t(); |
388 | |
389 | status_t create_kernel(); |
390 | void operator()(brgemm_kernel_params_t *) const; |
391 | |
392 | private: |
393 | jit_brgemm_kernel_t<isa, Vmm> *brgemm_kernel_ = nullptr; |
394 | |
395 | DNNL_DISALLOW_COPY_AND_ASSIGN(brgemm_kernel_common_t); |
396 | }; |
397 | |
398 | struct brgemm_amx_uker_t : public brgemm_kernel_t { |
399 | brgemm_amx_uker_t(const brgemm_t abrd); |
400 | ~brgemm_amx_uker_t(); |
401 | |
402 | status_t create_kernel(); |
403 | void operator()(brgemm_kernel_params_t *) const; |
404 | |
405 | private: |
406 | jit_brgemm_amx_uker_base_t *brgemm_kernel_ = nullptr; |
407 | |
408 | DNNL_DISALLOW_COPY_AND_ASSIGN(brgemm_amx_uker_t); |
409 | }; |
410 | |
411 | template <cpu_isa_t isa, typename Vmm> |
412 | struct brdgmm_kernel_t : public brgemm_kernel_t { |
413 | brdgmm_kernel_t(const brgemm_t abrd); |
414 | ~brdgmm_kernel_t(); |
415 | |
416 | status_t create_kernel(); |
417 | void operator()(brgemm_kernel_params_t *) const; |
418 | |
419 | private: |
420 | jit_brdgmm_kernel_base_t<isa, Vmm> *brgemm_kernel_ = nullptr; |
421 | |
422 | DNNL_DISALLOW_COPY_AND_ASSIGN(brdgmm_kernel_t); |
423 | }; |
424 | |
425 | /// @param bias Vector of bias (vector length is N) |
426 | /// @param scales Vector of scales (vector length is N) |
427 | /// @param binary_post_ops_rhs - Ptr to table of pointers to tensors used as rhs |
428 | /// in binary post-operation { void* binary_op_tensor1, ..., |
429 | /// void* binary_op_tensor_n} |
430 | /// @param oc_logical_off - Used in binary postops in per_oc bcast strategy. |
431 | /// Offset to start oc processed by given thread in elements. |
/// @param dst_row_logical_off - Used in binary postops in per_oc bcast
///     strategy. Offset to the start of the destination row processed by
///     given thread, in elements.
434 | /// @param a_zp_compensations - Pre-computed compensations for A matrix zero |
435 | /// point values. |
436 | /// @param b_zp_compensations - Pre-computed compensations for B matrix zero |
437 | /// point values. |
438 | /// @param c_zp_values - C matrix zero point values. |
439 | /// @param skip_accumulation - specifies whether to skip accumulation when |
440 | /// computing post-ops. `Beta` value from descriptor affects final |
441 | /// accumulator values taken. |
442 | /// |
443 | struct brgemm_post_ops_data_t { |
444 | brgemm_post_ops_data_t() = default; |
445 | brgemm_post_ops_data_t(const void *bias, const float *scales, |
446 | const void *binary_post_ops_rhs, size_t oc_logical_off, |
447 | const size_t dst_row_logical_off = 0, char *data_C_ptr_ = nullptr, |
448 | const size_t first_mb_matrix_addr_off = 0, |
449 | const void *a_zp_compensations = nullptr, |
450 | const void *b_zp_compensations = nullptr, |
451 | const void *c_zp_values = nullptr, bool skip_accumulation = false, |
452 | int32_t zp_a_val = 1, bool do_only_comp = false, |
453 | bool do_only_zp_a_val = false) |
454 | : bias(bias) |
455 | , scales(scales) |
456 | , binary_post_ops_rhs(binary_post_ops_rhs) |
457 | , oc_logical_off(oc_logical_off) |
458 | , dst_row_logical_off(dst_row_logical_off) |
459 | , data_C_ptr_(data_C_ptr_) |
460 | , first_mb_matrix_addr_off(first_mb_matrix_addr_off) |
461 | , a_zp_compensations(a_zp_compensations) |
462 | , b_zp_compensations(b_zp_compensations) |
463 | , c_zp_values(c_zp_values) |
464 | , skip_accumulation(skip_accumulation) |
465 | , zp_a_val {zp_a_val} |
466 | , do_only_comp {do_only_comp} |
467 | , do_only_zp_a_val {do_only_zp_a_val} {} |
468 | |
469 | const void *bias = nullptr; |
470 | const float *scales = nullptr; |
471 | const void *binary_post_ops_rhs = nullptr; |
472 | size_t oc_logical_off = 0; |
473 | size_t dst_row_logical_off = 0; |
474 | char *data_C_ptr_ = nullptr; |
475 | size_t first_mb_matrix_addr_off = 0; |
476 | const void *a_zp_compensations = nullptr; |
477 | const void *b_zp_compensations = nullptr; |
478 | const void *c_zp_values = nullptr; |
479 | const bool skip_accumulation = false; |
480 | int32_t zp_a_val = 1; |
481 | const bool do_only_comp = false; |
482 | const bool do_only_zp_a_val = false; |
483 | }; |
484 | |
485 | } // namespace x64 |
486 | } // namespace cpu |
487 | } // namespace impl |
488 | } // namespace dnnl |
489 | |
490 | #endif |
491 | |
492 | //vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
493 | |