1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_BRGEMM_BRGEMM_HPP
18#define CPU_X64_BRGEMM_BRGEMM_HPP
19
20#include "cpu/x64/brgemm/brgemm_types.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace cpu {
25namespace x64 {
26/// Initializes a BRGEMM descriptor
27///
28/// @param brg Output BRGEMM descriptor
29/// @param isa Target ISA of BRGEMM kernel
30/// If isa is equal to 'isa_undef' maximum supported ISA on current
31/// hardware will be used for BRGEMM kernel generation
32/// @param type Type of batch
33/// @param dt_a Data type of A matrix, can be
34/// AVX512: f32, u8(row-major layout), s8(column-major layout), bf16, f16
35/// AMX: u8, s8, bf16, f16
36/// @param dt_b Data type of B matrix
37/// AVX512: f32, s8(row-major layout), u8(column-major layout), bf16, f16
38/// AMX: u8, s8, bf16, f16
39/// @note
40/// Data type of matrix C depends on data types of matrices A and B
41/// If A and B have integer u8/s8 data type, C has int32 data type
42/// If A and B have bf16 or f16 or f32 data type, C has f32 data type
43/// @param transA Specifies the form of A used in the matrix multiplication
44/// 'false' - A is not transposed, 'true' - A is transposed
45/// @param transB Specifies the form of B used in the matrix multiplication
46/// 'false' - B is not transposed, 'true' - B is transposed
47/// @param layout Specifies whether two-dimensional array storage is row-major
48/// (brgemm_row_major) or column-major (brgemm_col_major).
49/// @param alpha Specifies the scalar alpha
50/// @param beta Specifies the scalar beta
51/// @param LDA Specifies the leading dimension of matrix A.
52/// LDA must be at least max(1, K)
53/// @param LDB Specifies the leading dimension of matrix B.
54/// LDB must be at least max(1, N)
55/// @param LDC Specifies the leading dimension of matrix C.
56/// LDC must be at least max(1, N)
57/// @param M Specifies the number of rows of the matrix A and of the matrix C.
58/// @param N Specifies the number of columns of the matrix B and
59/// the number of columns of the matrix C
60/// @param K Specifies the number of columns of the matrix A and
61/// the number of rows of the matrix B
62/// @param strides Strides between the matrices in the batch. Can be nullptr.
63///
64status_t DNNL_API brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
65 brgemm_batch_kind_t type, impl::data_type_t dt_a,
66 impl::data_type_t dt_b, bool transA, bool transB,
67 brgemm_layout_t layout, float alpha, float beta, dim_t LDA, dim_t LDB,
68 dim_t LDC, dim_t M, dim_t N, dim_t K,
69 const brgemm_strides_t *strides = nullptr);
70
71/// Initializes a BRGEMM descriptor with B matrix as a diagonal matrix
72/// represented in packed vector format.
73///
74/// @param brg Output BRGEMM descriptor
75/// @param isa Target ISA of BRGEMM kernel
76/// If isa is equal to 'isa_undef' maximum supported ISA on current
77/// hardware will be used for BRGEMM kernel generation
78/// @param type Type of batch
79/// @param dt_a Data type of A matrix can be: f32, u8, bf16, f16
80/// @param dt_b Data type of B vector can be: f32, s8, bf16, f16
81/// @note
82/// Data type of matrix C depends on data types of matrices A and vector B
83/// If A and B have integer u8/s8 data type, C has int32 data type
84/// If A and B have bf16 or f16 or f32 data type, C has f32 data type
85/// @param transA Specifies the form of A used in the matrix multiplication
86/// 'false' - A is not transposed, 'true' - A is transposed
87/// @param layout Specifies whether two-dimensional array storage is row-major
88/// (brgemm_row_major) or column-major (brgemm_col_major).
89/// @param alpha Specifies the scalar alpha
90/// @param beta Specifies the scalar beta
91/// @param LDA Specifies the leading dimension of matrix A.
92/// LDA must be at least max(1, N)
93/// @param LDC Specifies the leading dimension of matrix C.
94/// LDC must be at least max(1, N)
95/// @param M Specifies the number of rows of the matrix A and C.
96/// @param N Specifies the number of columns of the matrix A and C.
97///
98status_t DNNL_API brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa,
99 brgemm_batch_kind_t type, impl::data_type_t dt_a,
100 impl::data_type_t dt_b, bool transA, brgemm_layout_t layout,
101 float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
102 const brgemm_strides_t *strides = nullptr);
103
104/// Adds post-operations to BRGEMM descriptor
105///
106/// @param brg Output BRGEMM descriptor
107/// @param attr Primitive attributes (can be nullptr). Specifies post-ops
108/// operations
109/// @param dst_md Specifies the memory descriptor of the destination tensor,
110/// needed for binary postops to determine broadcast type, as well as to
111/// determine dst data type.
112/// @param LDD Specifies the leading dimension of matrix D
113/// LDD must be at least max(1, N)
114/// @param dt_bias Specifies the data type Bias
115/// Can be u8, s8, s32, bf16, f16 or fp32
116///
117status_t DNNL_API brgemm_desc_set_postops(brgemm_t *brg,
118 const primitive_attr_t *attr, const memory_desc_t *dst_md, int LDD,
119 impl::data_type_t dt_bias = impl::data_type::undef);
120
121/// Adds BRGEMM attributes to BRGEMM descriptor
122///
123/// @param brg Output BRGEMM descriptor
124/// @param brgattr Specifies kernel attributes and hints: virtual padding,
125/// maximum batch size, kernel loop order etc.
126///
127status_t DNNL_API brgemm_desc_set_attr(
128 brgemm_t *brg, const brgemm_attr_t &brgattr);
129
130/// Generates a BRGEMM kernel based on descriptor
131///
132/// @param brg_kernel Output BRGEMM kernel
133/// @param brg BRGEMM descriptor
134///
135status_t DNNL_API brgemm_kernel_create(
136 brgemm_kernel_t **brg_kernel, const brgemm_t &brg);
137
138/// Destroys a BRGEMM kernel
139///
140/// @param brg_kernel BRGEMM kernel
141///
142status_t DNNL_API brgemm_kernel_destroy(brgemm_kernel_t *brg_kernel);
143
144/// Execute BRGEMM kernel (brgemm_addr version)
145///
146/// @note
147/// Only BRGEMM kernel will be executed even if post-ops are added to BRGEMM
148/// descriptor
149///
150/// @note
151/// In row major mode matrix B (matrix A for column major) is expected to be
152/// in a VNNI-friendly format, which requires 4 consecutive elements of K
153/// dimension for int8 data type, 2 elements for bfloat16 data type and no
154/// requirements for f32 and f16 data types.
155///
156/// @param brg_kernel BRGEMM kernel
157/// @param bs Specifies the size of batch
158/// @param batch Array of batch elements containing pointers to matrices
159/// A,B and virtual padding for matrices A
160/// @param ptr_C Pointer to destination matrix C
161/// @param scratch Scratchpad memory needed in several scenarios:
162/// * Where: AMX+ hardware; When: always; For: buffer for tiles store.
163/// * In rest scenarios is not used.
164///
165void DNNL_API brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
166 const brgemm_batch_element_t *batch, void *ptr_C,
167 void *scratch = nullptr);
168
169/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
170///
171/// @note
172/// Only BRGEMM kernel will be executed even if post-ops are added to BRGEMM
173/// descriptor
174///
175/// @note
176/// See the second note for `brgemm_kernel_execute` API.
177///
178/// @param brg_kernel BRGEMM kernel
179/// @param bs Specifies the size of batch
180/// @param addr_A Pointer to first matrix A in the batch
181/// @param addr_B Pointer to first matrix B in the batch
182/// @param batch Array of batch elements containing offsets to matrices A,B
183/// and virtual padding for matrix A
184/// @param ptr_C Pointer to destination matrix C
185/// @param scratch Scratchpad memory needed in several scenarios:
186/// * Where: AMX+ hardware; When: always; For: buffer for tiles store.
187/// * In rest scenarios is not used.
188///
189void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
190 const void *addr_A, const void *addr_B,
191 const brgemm_batch_element_t *batch, void *ptr_C,
192 void *scratch = nullptr);
193
194/// Execute BRGEMM kernel (brgemm_addr version)
195///
196/// @note
197/// BRGEMM kernel and post-operations will be executed
198///
199/// @note
200/// See the second note for `brgemm_kernel_execute` API.
201///
202/// @param brg_kernel BRGEMM kernel
203/// @param bs Specifies the size of batch
204/// @param batch Array of batch elements containing pointers to matrices A,B
205/// and virtual padding for matrices A
206/// @param ptr_C Pointer to matrix C
207/// @param ptr_D Pointer to destination matrix D
208/// @param post_ops_data Specifies tensors and data used in post processing
209/// phase
210/// @param scratch Scratchpad memory needed in several scenarios:
211/// * Where: AMX+ hardware; When: always; For: buffer for tiles store.
212/// * Where: pre-VNNI hardware; When: s8s8 kernel; For: compensation buffer.
213/// * In rest scenarios is not used.
214///
215void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel,
216 int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
217 const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr);
218
219/// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version)
220///
221/// @note
222/// BRGEMM kernel and post-operations will be executed
223///
224/// @note
225/// See the second note for `brgemm_kernel_execute` API.
226///
227/// @param brg_kernel BRGEMM kernel
228/// @param bs Specifies the size of batch
229/// @param addr_A Pointer to first matrix A in the batch
230/// @param addr_B Pointer to first matrix B in the batch
231/// @param batch Array of batch elements containing offsets to matrices A,B
232/// and virtual padding for matrices A
233/// @param ptr_C Pointer to destination matrix C
234/// @param ptr_D Pointer to destination matrix D
235/// @param post_ops_data Specifies tensors and data used in post processing
236/// phase
237/// @param scratch Scratchpad memory needed in several scenarios:
238/// * Where: AMX+ hardware; When: always; For: buffer for tiles store.
239/// * Where: pre-VNNI hardware; When: s8s8 kernel; For: compensation buffer.
240/// * In rest scenarios is not used.
241///
242void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
243 const void *addr_A, const void *addr_B,
244 const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
245 const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr);
246
247/// AMX utilities: Creates a palette based on BRGEMM descriptor
248///
249/// @note
250/// This call expects brgemm_t object completely set up, thus, used after
251/// `brgemm_desc_set_attr` call for non-empty attributes.
252///
253/// @note
254/// Caller is expected to subsequently configure AMX tiles by calling
255/// amx_tile_configure(palette).
256///
257/// @param brg BRGEMM descriptor
258/// @param palette 64 bytes array contains tiles configuration
259///
260status_t DNNL_API brgemm_init_tiles(const brgemm_t &brg, char palette[64]);
261
262} // namespace x64
263} // namespace cpu
264} // namespace impl
265} // namespace dnnl
266
267#endif
268
// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
270