/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_BRGEMM_BRGEMM_TYPES_HPP
#define CPU_X64_BRGEMM_BRGEMM_TYPES_HPP

#include "common/primitive_attr.hpp"
#include "cpu/platform.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

// The type defines the organization of a batch of matrices.
typedef enum {
    // A and B are arrays of pointers to matrices.
    brgemm_addr = 1,
    // Base address and array of offsets from the base address.
    brgemm_offs = 2,
    // Base address and fixed stride between matrices.
    brgemm_strd = 3,
} brgemm_batch_kind_t;
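
// Illustrative sketch (an assumption, not part of this header): for batch
// element i, a kernel resolves the effective A address roughly as follows,
// depending on the batch kind (brgemm_batch_element_t is defined below),
// and symmetrically for B:
//
//   switch (kind) {
//       case brgemm_addr: A_i = batch[i].ptr.A; break;
//       case brgemm_offs: A_i = ptr_A + batch[i].offset.A; break;
//       case brgemm_strd: A_i = ptr_A + i * stride_a; break;
//   }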

// The type defines the storage format of a matrix.
typedef enum {
    brgemm_layout_undef = 0,
    brgemm_col_major = 1,
    brgemm_row_major = 2,
} brgemm_layout_t;

typedef enum {
    none = 0,
    per_tensor = 1,
    per_m = 2,
    per_n = 3,
    per_k = 4,
} brgemm_broadcast_t;

struct brgemm_strides_t {
    // Stride (in bytes) between A matrices
    dim_t stride_a;
    // Stride (in bytes) between B matrices
    dim_t stride_b;
};

typedef enum {
    brgemm_lo_default = 0,
    brgemm_lo_bl_1load,
    brgemm_lo_bl_1bcst,
} brgemm_kernel_loop_order_t;

typedef enum {
    brgemm_prf_default = 1,
    brgemm_prf1 = 2,
    brgemm_prf2 = 3,
} brgemm_kernel_prefetching_t;

typedef enum {
    brgemm_bd_loop_innermost = 0,
    brgemm_ld_loop_innermost,
} brgemm_kernel_innermost_loop_t;

typedef enum {
    brgemm_hint_nt_undef = -1,
    brgemm_hint_nt_false = 0,
    brgemm_hint_nt_true = 1,
} brgemm_kernel_hint_nt_t;

// Prefetch distances; -1 means unset (use the implementation default).
struct brgemm_prf_t {
    int dist1 = -1;
    int dist2 = -1;
};

struct DNNL_API brgemm_attr_t {
    brgemm_attr_t();
    // If the unrolled kernel is used (use_uker == true), "max_bs" is the
    // only batch size that can be used on a kernel call; otherwise it is
    // the maximum batch size that can be used.
    int max_bs;
    int max_top_vpad, max_bottom_vpad;
    dim_t hint_expected_A_size, hint_expected_B_size, hint_expected_C_size;
    brgemm_kernel_innermost_loop_t hint_innermost_loop
            = brgemm_ld_loop_innermost;
    brgemm_kernel_loop_order_t hint_loop_order;
    brgemm_kernel_prefetching_t hint_prefetching
            = brgemm_kernel_prefetching_t::brgemm_prf_default;
    brgemm_prf_t hint_prfA, hint_prfB, hint_prfC;

    bool wary_tail_read;
    bool generate_skip_accumulation;
    // bd_mask is a char array in which each element is a boolean value that
    // determines whether to write the corresponding row to the result matrix
    // or skip it (see the sketch after this struct).
    char *bd_mask;
    // bd_mask_level specifies how bd_mask is used in the brgemm kernel:
    // 0 - bd_mask is not used
    // 1 - bd_mask is used on the storing stage only
    // 2 - bd_mask is used on both the reading and storing stages
    int bd_mask_level;
    // use_uker determines whether to use the unrolled kernel.
    bool use_uker;
    // use_interleave_stores determines whether to use interleave stores.
    bool use_interleave_stores;
    impl::fpmath_mode_t fpmath_mode = fpmath_mode::strict;
    // Second-level leading dimensions describing the distance between 16-line
    // blocks in case of blocked layout. Used to calculate the address of the
    // next bd block. By default they are equal to the regular leading
    // dimension parameters specified on brgemm creation.
    // Supported by the brgemm unrolled kernel only for now.
    int LDA2 {0}, LDB2 {0}, LDC2_M {0}, LDC2_N {0};
    // If "true", the batch size is allowed to change on each kernel call,
    // and there is no unrolling by batch size in the kernel.
    bool var_bs {false};
    bool postops_only {false};

    int hint_bd_block {0};
    int hint_ld_block {0};
    int hint_bd_block2 {0};
    int hint_ld_block2 {0};

    brgemm_kernel_hint_nt_t hint_load_nt_A {brgemm_hint_nt_undef};
    brgemm_kernel_hint_nt_t hint_load_nt_B {brgemm_hint_nt_undef};
};
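
// Illustrative sketch (an assumption about usage, not part of the API): with
// bd_mask_level > 0, a caller provides one flag per row of the bcast (M)
// dimension; rows flagged 0 are skipped when storing to the result matrix.
//
//   // hypothetical setup: keep even rows, skip odd ones
//   std::vector<char> mask(M);
//   for (int m = 0; m < M; m++)
//       mask[m] = (m % 2 == 0);
//   brgemm_attr_t attr;
//   attr.bd_mask = mask.data();
//   attr.bd_mask_level = 1; // apply the mask on the storing stage only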

struct brgemm_batch_element_t {
    brgemm_batch_element_t() {
        ptr.A = ptr.B = nullptr;
        vvpad.top = vvpad.bottom = 0;
    }
    union {
        struct {
            const void *A;
            const void *B;
        } ptr;
        struct {
            dim_t A;
            dim_t B;
        } offset;
    };
    union {
        struct {
            dim_t top;
            dim_t bottom;
        } vvpad;
        struct {
            dim_t left;
            dim_t right;
        } hvpad;
    };
};
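
// A minimal usage sketch (assumed, not normative): the active union member
// must match the batch kind the kernel was created with. Names `A_chunks`,
// `B_chunks`, and the strides are hypothetical caller-side values.
//
//   // brgemm_addr: each element carries raw pointers
//   std::vector<brgemm_batch_element_t> batch(bs);
//   for (int i = 0; i < bs; i++) {
//       batch[i].ptr.A = A_chunks[i];
//       batch[i].ptr.B = B_chunks[i];
//   }
//
//   // brgemm_offs: elements carry offsets from a common base instead
//   for (int i = 0; i < bs; i++) {
//       batch[i].offset.A = i * stride_a;
//       batch[i].offset.B = i * stride_b;
//   }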

struct brgemm_t {
    int bcast_dim = 0; // M
    int load_dim = 0; // N
    int reduce_dim = 0; // K
    int LDA = 0;
    int LDB = 0;
    int LDC = 0;
    int LDD = 0;
    // We use two isa_ variables:
    // isa_user stores the user-provided isa value;
    // isa_impl stores the actual implementation isa. The latter can change
    // until the kernel is created, as calls to set_attr can affect it
    // (e.g. bf32).
    impl::cpu::x64::cpu_isa_t isa_user = isa_undef;
    impl::cpu::x64::cpu_isa_t isa_impl = isa_undef;

    int LDA2 {0}, LDB2 {0}, LDC2_M {0}, LDC2_N {0};
    bool is_blocked = false;

    float alpha = 0.0f;
    float beta = 0.0f;

    int bdb = 0, bd_block = 0, bdb_tail = 0;
    int bdb2 = 0, bd_block2 = 0, bdb2_tail = 0;
    int ldb = 0, ld_block = 0, ldb_tail = 0;
    int ldb2 = 0, ld_block2 = 0, ldb2_tail = 0;
    int rdb = 0, rd_block = 0, rdb_tail = 0;
    int rd_step = 0, ld_step = 0;

    impl::data_type_t dt_a = data_type::undef;
    impl::data_type_t dt_c = data_type::undef;
    impl::data_type_t dt_b = data_type::undef;
    impl::data_type_t dt_d = data_type::undef;
    impl::data_type_t dt_bias = data_type::undef;

    int typesize_A = 0;
    int typesize_B = 0;
    int typesize_C = 0;
    int typesize_D = 0;
    int typesize_bias = 0;

    bool is_ymm = false;
    bool is_zmm = false;
    bool is_tmm = false;
    bool is_int8 = false, is_int8_tmm = false;
    bool is_bf16 = false, is_bf16_tmm = false, is_bf16_emu = false;
    bool is_f16 = false, is_f16_tmm = false;
    bool is_f32 = false;
    bool is_bf32 = false;

    dim_t stride_a = 0; // Offset in bytes
    dim_t stride_b = 0;

    brgemm_layout_t layout = brgemm_layout_undef;
    brgemm_batch_kind_t type;

    bool load_nt_A = false;
    bool load_nt_B = false;
    bool embd_bcst = false;
    bool is_dgmm = false; // set to true in brdgmm_desc_init
    bool with_bias = false;
    bool with_sum = false;
    float sum_scale = 0.0f;
    int32_t sum_zp = 0;
    impl::data_type_t sum_dt;
    bool with_eltwise = false;
    bool with_binary = false;
    bool with_scales = false;
    bool req_cal_comp_pads = false;
    bool req_s8s8_compensation = false;
    brgemm_broadcast_t zp_type_a = brgemm_broadcast_t::none;
    brgemm_broadcast_t zp_type_b = brgemm_broadcast_t::none;
    brgemm_broadcast_t zp_type_c = brgemm_broadcast_t::none;

    int is_oc_scale = 0;

    const primitive_attr_t *attr = nullptr;
    const memory_desc_t *dst_md = nullptr;

    brgemm_attr_t brgattr;
    static constexpr int MAX_VPAD = 100;
    static constexpr int AMX_TILES_NUM = 8;

    int is_M_tail;

    bool interleave_tilestores_ = false;

    brgemm_prf_t prfA, prfB, prfC;

    bool is_row_major() const {
        assert(layout != brgemm_layout_undef);
        return layout == brgemm_row_major;
    }

    // Tile register decomposition
    int get_bd_block2() const noexcept {
        return (bdb_tail) ? bd_block2 + 1 : bd_block2;
    }
    int get_ld_block2() const noexcept {
        return (ldb_tail) ? ld_block2 + 1 : ld_block2;
    }
    int get_num_C_tiles() const noexcept {
        return get_bd_block2() * get_ld_block2();
    }
    int get_C_tensor(int m, int n, bool m_tail = false,
            bool n_tail = false) const noexcept {
        auto M = m_tail ? get_bd_block2() - 1 : m;
        auto N = n_tail ? get_ld_block2() - 1 : n;
        return (M * get_ld_block2() + N);
    }

    int tiles_for_A() const noexcept {
        return (AMX_TILES_NUM - get_num_C_tiles() - 1);
    }

    int get_A_tensor(int m, bool m_tail = false) const noexcept {
        auto full_A_tiles = get_num_A_tiles() - (bdb_tail ? 1 : 0);
        auto M = m_tail ? get_num_A_tiles() - 1 : m % full_A_tiles;
        return (get_num_C_tiles() + M);
    }

    int get_num_A_tiles() const noexcept {
        return nstl::min(get_bd_block2(), tiles_for_A());
    }

    int tiles_for_B() const noexcept {
        return (AMX_TILES_NUM - get_num_C_tiles() - get_num_A_tiles());
    }

    int get_B_tensor(int n, bool n_tail = false) const noexcept {
        auto full_B_tiles = get_num_B_tiles() - (ldb_tail ? 1 : 0);
        auto N = n_tail ? get_num_B_tiles() - 1 : n % full_B_tiles;
        return (get_num_C_tiles() + get_num_A_tiles() + N);
    }

    int get_num_B_tiles() const noexcept {
        return nstl::min(get_ld_block2(), tiles_for_B());
    }
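
    // Worked example (illustrative; the blocking numbers are assumed):
    // AMX provides AMX_TILES_NUM = 8 tile registers. With bd_block2 = 2,
    // ld_block2 = 2 and no tails, C takes get_num_C_tiles() = 2 * 2 = 4
    // tiles (indices 0..3), A takes min(2, 8 - 4 - 1) = 2 tiles (indices
    // 4..5), and B gets the remaining 8 - 4 - 2 = 2 tiles (indices 6..7):
    //
    //   get_C_tensor(1, 0); // -> 1 * 2 + 0 = 2
    //   get_A_tensor(1);    // -> 4 + (1 % 2) = 5
    //   get_B_tensor(0);    // -> 4 + 2 + 0 = 6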

    int get_wsp_buffer_size() const noexcept {
        int sz = 0;
        if (is_tmm) {
            constexpr int tilesize = 1024;
            sz = get_num_C_tiles() * tilesize; // postops buffer
            if (is_bf32) {
                const int n_bdb = bd_block2;
                const int n_rdb = rdb + (rdb_tail != 0);
                const int n_ldb = ldb + (ldb_tail != 0);
                const int downcvt_tiles
                        = brgattr.max_bs * n_rdb * (n_bdb + n_ldb);
                sz += downcvt_tiles * tilesize;
            }
        }
        return sz;
    }
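
    // For example (an assumed configuration): a tile buffer is 1 KiB
    // (16 rows x 64 bytes), so with 4 C tiles and no bf32 down-conversion
    // buffers the workspace is 4 * 1024 = 4096 bytes.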

    bool is_b_data_layout_vnni() {
        // True in general; the only exception is f16 with avx512_core_fp16.
        // We also return `true` for bf32 (brgattr.fpmath_mode = bf16),
        // because the data transformation to the vnni layout is internal
        // and transparent to the user.
        return !(dt_b == data_type::f16 && isa_impl == avx512_core_fp16);
    }
};

struct brgemm_kernel_params_t {
    const void *ptr_A;
    const void *ptr_B;
    const brgemm_batch_element_t *batch;
    void *ptr_C;

    const void *ptr_bias;
    void *ptr_D;

    /* The kernel takes a single pointer to scales, but the configuration
     * relies on a combination of arg scales. This helps to reuse attributes
     * from primitives, but requires pre-computing
     *     scales = src_scale * wei_scale[:]
     * (see the sketch after this struct).
     */
    const void *ptr_scales;
    void *ptr_buf;

    size_t do_post_ops;
    size_t do_apply_comp;
    size_t BS;

    /* Pointer to a table of void * elements that are pointers to post_op
     * binary src1 tensors.
     */
    const void *post_ops_binary_rhs_arg_vec;
    size_t oc_logical_off;
    size_t first_mb_matrix_addr_off;
    size_t dst_row_logical_off;

    char *data_C_ptr_;

    const void *a_zp_compensations = nullptr;
    const void *b_zp_compensations = nullptr;
    const void *c_zp_values = nullptr;
    size_t skip_accm = 0;
    int32_t zp_a_val = 1;
};
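
// A minimal sketch of the pre-computation mentioned above (assuming
// per-channel weight scales; `src_scale`, `wei_scales`, and `n_channels`
// are hypothetical caller-side names):
//
//   std::vector<float> scales(n_channels);
//   for (int c = 0; c < n_channels; c++)
//       scales[c] = src_scale * wei_scales[c];
//   // then: params.ptr_scales = scales.data();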

template <cpu_isa_t isa, typename Vmm>
struct jit_brgemm_kernel_t;
struct jit_brgemm_amx_uker_base_t;
template <cpu_isa_t isa, typename Vmm>
struct jit_brdgmm_kernel_base_t;

struct brgemm_kernel_t {
    brgemm_kernel_t() {}
    virtual ~brgemm_kernel_t() {}
    virtual status_t create_kernel() = 0;
    virtual void operator()(brgemm_kernel_params_t *) const = 0;
};
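
// Typical call pattern (a sketch under assumptions; kernels are normally
// obtained through the factory functions declared in brgemm.hpp, e.g.
// brgemm_kernel_create):
//
//   brgemm_kernel_t *kernel = nullptr;
//   CHECK(brgemm_kernel_create(&kernel, brg)); // brg is a brgemm_t
//   brgemm_kernel_params_t params;
//   // ... fill params (ptr_A/ptr_B or batch, ptr_C, BS, ...) ...
//   (*kernel)(&params);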

template <cpu_isa_t isa, typename Vmm>
struct brgemm_kernel_common_t : public brgemm_kernel_t {
    brgemm_kernel_common_t(const brgemm_t abrd);
    ~brgemm_kernel_common_t();

    status_t create_kernel() override;
    void operator()(brgemm_kernel_params_t *) const override;

private:
    jit_brgemm_kernel_t<isa, Vmm> *brgemm_kernel_ = nullptr;

    DNNL_DISALLOW_COPY_AND_ASSIGN(brgemm_kernel_common_t);
};

struct brgemm_amx_uker_t : public brgemm_kernel_t {
    brgemm_amx_uker_t(const brgemm_t abrd);
    ~brgemm_amx_uker_t();

    status_t create_kernel() override;
    void operator()(brgemm_kernel_params_t *) const override;

private:
    jit_brgemm_amx_uker_base_t *brgemm_kernel_ = nullptr;

    DNNL_DISALLOW_COPY_AND_ASSIGN(brgemm_amx_uker_t);
};

template <cpu_isa_t isa, typename Vmm>
struct brdgmm_kernel_t : public brgemm_kernel_t {
    brdgmm_kernel_t(const brgemm_t abrd);
    ~brdgmm_kernel_t();

    status_t create_kernel() override;
    void operator()(brgemm_kernel_params_t *) const override;

private:
    jit_brdgmm_kernel_base_t<isa, Vmm> *brgemm_kernel_ = nullptr;

    DNNL_DISALLOW_COPY_AND_ASSIGN(brdgmm_kernel_t);
};

/// @param bias Vector of bias (vector length is N)
/// @param scales Vector of scales (vector length is N)
/// @param binary_post_ops_rhs - Ptr to table of pointers to tensors used as
///     rhs in binary post-operations { void* binary_op_tensor1, ...,
///     void* binary_op_tensor_n }
/// @param oc_logical_off - Used in binary postops in the per_oc bcast
///     strategy. Offset to the start of the oc processed by the given thread,
///     in elements.
/// @param dst_row_logical_off - Used in binary postops in the per_oc bcast
///     strategy. Offset to the start of the destination row processed by the
///     given thread, in elements.
/// @param a_zp_compensations - Pre-computed compensations for A matrix zero
///     point values.
/// @param b_zp_compensations - Pre-computed compensations for B matrix zero
///     point values.
/// @param c_zp_values - C matrix zero point values.
/// @param skip_accumulation - Specifies whether to skip accumulation when
///     computing post-ops. The `beta` value from the descriptor affects which
///     accumulator values are taken.
///
struct brgemm_post_ops_data_t {
    brgemm_post_ops_data_t() = default;
    brgemm_post_ops_data_t(const void *bias, const float *scales,
            const void *binary_post_ops_rhs, size_t oc_logical_off,
            const size_t dst_row_logical_off = 0, char *data_C_ptr_ = nullptr,
            const size_t first_mb_matrix_addr_off = 0,
            const void *a_zp_compensations = nullptr,
            const void *b_zp_compensations = nullptr,
            const void *c_zp_values = nullptr, bool skip_accumulation = false,
            int32_t zp_a_val = 1, bool do_only_comp = false,
            bool do_only_zp_a_val = false)
        : bias(bias)
        , scales(scales)
        , binary_post_ops_rhs(binary_post_ops_rhs)
        , oc_logical_off(oc_logical_off)
        , dst_row_logical_off(dst_row_logical_off)
        , data_C_ptr_(data_C_ptr_)
        , first_mb_matrix_addr_off(first_mb_matrix_addr_off)
        , a_zp_compensations(a_zp_compensations)
        , b_zp_compensations(b_zp_compensations)
        , c_zp_values(c_zp_values)
        , skip_accumulation(skip_accumulation)
        , zp_a_val {zp_a_val}
        , do_only_comp {do_only_comp}
        , do_only_zp_a_val {do_only_zp_a_val} {}

    const void *bias = nullptr;
    const float *scales = nullptr;
    const void *binary_post_ops_rhs = nullptr;
    size_t oc_logical_off = 0;
    size_t dst_row_logical_off = 0;
    char *data_C_ptr_ = nullptr;
    size_t first_mb_matrix_addr_off = 0;
    const void *a_zp_compensations = nullptr;
    const void *b_zp_compensations = nullptr;
    const void *c_zp_values = nullptr;
    const bool skip_accumulation = false;
    int32_t zp_a_val = 1;
    const bool do_only_comp = false;
    const bool do_only_zp_a_val = false;
};
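
// A usage sketch (with hypothetical caller-side names): most arguments are
// defaulted, so a simple bias + per-N scales configuration reduces to:
//
//   brgemm_post_ops_data_t po_data(
//           bias_ptr, scales_ptr, /* binary_post_ops_rhs = */ nullptr,
//           /* oc_logical_off = */ 0);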

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s