1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_BRGEMM_BRGEMM_HPP |
18 | #define CPU_X64_BRGEMM_BRGEMM_HPP |
19 | |
20 | #include "cpu/x64/brgemm/brgemm_types.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | /// Initializes a BRGEMM descriptor |
27 | /// |
28 | /// @param brg Output BRGEMM descriptor |
29 | /// @param isa Target ISA of BRGEMM kernel |
30 | /// If isa is equal to 'isa_undef' maximum supported ISA on current |
31 | /// hardware will be used for BRGEMM kernel generation |
32 | /// @param type Type of batch |
33 | /// @param dt_a Data type of A matrix, can be |
34 | /// AVX512: f32, u8(row-major layout), s8(column-major layout), bf16, f16 |
35 | /// AMX: u8, s8, bf16, f16 |
36 | /// @param dt_b Data type of B matrix |
37 | /// AVX512: f32, s8(row-major layout), u8(column-major layout), bf16, f16 |
38 | /// AMX: u8, s8, bf16, f16 |
39 | /// @note |
40 | /// Data type of matrix C depends on data types of matrices A and B |
41 | /// If A and B have integer u8/s8 data type, C has int32 data type |
42 | /// If A and B have bf16 or f16 or f32 data type, C has f32 data type |
43 | /// @param transA Specifies the form of A used in the matrix multiplication |
44 | /// 'false' - A is not transposed, 'true' - A is transposed |
45 | /// @param transB Specifies the form of B used in the matrix multiplication |
46 | /// 'false' - B is not transposed, 'true' - B is transposed |
47 | /// @param layout Specifies whether two-dimensional array storage is row-major |
48 | /// (brgemm_row_major) or column-major (brgemm_col_major). |
49 | /// @param alpha Specifies the scalar alpha |
50 | /// @param beta Specifies the scalar beta |
51 | /// @param LDA Specifies the leading dimension of matrix A. |
52 | /// LDA must be at least max(1, K) |
53 | /// @param LDB Specifies the leading dimension of matrix B. |
54 | /// LDB must be at least max(1, N) |
55 | /// @param LDC Specifies the leading dimension of matrix C. |
56 | /// LDC must be at least max(1, N) |
57 | /// @param M Specifies the number of rows of the matrix A and of the matrix C. |
58 | /// @param N Specifies the number of columns of the matrix B and |
59 | /// the number of columns of the matrix C |
60 | /// @param K Specifies the number of columns of the matrix A and |
61 | /// the number of rows of the matrix B |
62 | /// @param strides Strides between the matrices in the batch. Can be nullptr. |
63 | /// |
64 | status_t DNNL_API brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa, |
65 | brgemm_batch_kind_t type, impl::data_type_t dt_a, |
66 | impl::data_type_t dt_b, bool transA, bool transB, |
67 | brgemm_layout_t layout, float alpha, float beta, dim_t LDA, dim_t LDB, |
68 | dim_t LDC, dim_t M, dim_t N, dim_t K, |
69 | const brgemm_strides_t *strides = nullptr); |
70 | |
71 | /// Initializes a BRGEMM descriptor with B matrix as a diagonal matrix |
72 | /// represented in packed vector format. |
73 | /// |
74 | /// @param brg Output BRGEMM descriptor |
75 | /// @param isa Target ISA of BRGEMM kernel |
76 | /// If isa is equal to 'isa_undef' maximum supported ISA on current |
77 | /// hardware will be used for BRGEMM kernel generation |
78 | /// @param type Type of batch |
79 | /// @param dt_a Data type of A matrix can be: f32, u8, bf16, f16 |
80 | /// @param dt_b Data type of B vector can be: f32, s8, bf16, f16 |
81 | /// @note |
82 | /// Data type of matrix C depends on data types of matrices A and vector B |
83 | /// If A and B have integer u8/s8 data type, C has int32 data type |
84 | /// If A and B have bf16 or f16 or f32 data type, C has f32 data type |
85 | /// @param transA Specifies the form of A used in the matrix multiplication |
86 | /// 'false' - A is not transposed, 'true' - A is transposed |
87 | /// @param layout Specifies whether two-dimensional array storage is row-major |
88 | /// (brgemm_row_major) or column-major (brgemm_col_major). |
89 | /// @param alpha Specifies the scalar alpha |
90 | /// @param beta Specifies the scalar beta |
91 | /// @param LDA Specifies the leading dimension of matrix A. |
92 | /// LDA must be at least max(1, N) |
93 | /// @param LDC Specifies the leading dimension of matrix C. |
94 | /// LDC must be at least max(1, N) |
95 | /// @param M Specifies the number of rows of the matrix A and C. |
96 | /// @param N Specifies the number of columns of the matrix A and C. |
97 | /// |
98 | status_t DNNL_API brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa, |
99 | brgemm_batch_kind_t type, impl::data_type_t dt_a, |
100 | impl::data_type_t dt_b, bool transA, brgemm_layout_t layout, |
101 | float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N, |
102 | const brgemm_strides_t *strides = nullptr); |
103 | |
104 | /// Adds post-operations to BRGEMM descriptor |
105 | /// |
106 | /// @param brg Output BRGEMM descriptor |
107 | /// @param attr Primitive attributes (can be nullptr). Specifies post-ops |
108 | /// operations |
109 | /// @param dst_md Specifies the memory descriptor of the destination tensor, |
110 | /// needed for binary postops to determine broadcast type, as well as to |
111 | /// determine dst data type. |
112 | /// @param LDD Specifies the leading dimension of matrix D |
113 | /// LDD must be at least max(1, N) |
114 | /// @param dt_bias Specifies the data type Bias |
115 | /// Can be u8, s8, s32, bf16, f16 or fp32 |
116 | /// |
117 | status_t DNNL_API brgemm_desc_set_postops(brgemm_t *brg, |
118 | const primitive_attr_t *attr, const memory_desc_t *dst_md, int LDD, |
119 | impl::data_type_t dt_bias = impl::data_type::undef); |
120 | |
121 | /// Adds BRGEMM attributes to BRGEMM descriptor |
122 | /// |
123 | /// @param brg Output BRGEMM descriptor |
124 | /// @param brgattr Specifies kernel attributes and hints: virtual padding, |
125 | /// maximum batch size, kernel loop order etc. |
126 | /// |
127 | status_t DNNL_API brgemm_desc_set_attr( |
128 | brgemm_t *brg, const brgemm_attr_t &brgattr); |
129 | |
130 | /// Generates a BRGEMM kernel based on descriptor |
131 | /// |
132 | /// @param brg_kernel Output BRGEMM kernel |
133 | /// @param brg BRGEMM descriptor |
134 | /// |
135 | status_t DNNL_API brgemm_kernel_create( |
136 | brgemm_kernel_t **brg_kernel, const brgemm_t &brg); |
137 | |
138 | /// Destroys a BRGEMM kernel |
139 | /// |
140 | /// @param brg_kernel BRGEMM kernel |
141 | /// |
142 | status_t DNNL_API brgemm_kernel_destroy(brgemm_kernel_t *brg_kernel); |
143 | |
144 | /// Execute BRGEMM kernel (brgemm_addr version) |
145 | /// |
146 | /// @note |
147 | /// Only BRGEMM kernel will be executed even if post-ops are added to BRGEMM |
148 | /// descriptor |
149 | /// |
150 | /// @note |
151 | /// In row major mode matrix B (matrix A for column major) is expected to be |
152 | /// in a VNNI-friendly format, which requires 4 consecutive elements of K |
153 | /// dimension for int8 data type, 2 elements for bfloat16 data type and no |
154 | /// requirements for f32 and f16 data types. |
155 | /// |
156 | /// @param brg_kernel BRGEMM kernel |
157 | /// @param bs Specifies the size of batch |
158 | /// @param batch Array of batch elements containing pointers to matrices |
159 | /// A,B and virtual padding for matrices A |
160 | /// @param ptr_C Pointer to destination matrix C |
161 | /// @param scratch Scratchpad memory needed in several scenarios: |
162 | /// * Where: AMX+ hardware; When: always; For: buffer for tiles store. |
163 | /// * In rest scenarios is not used. |
164 | /// |
165 | void DNNL_API brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, |
166 | const brgemm_batch_element_t *batch, void *ptr_C, |
167 | void *scratch = nullptr); |
168 | |
169 | /// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version) |
170 | /// |
171 | /// @note |
172 | /// Only BRGEMM kernel will be executed even if post-ops are added to BRGEMM |
173 | /// descriptor |
174 | /// |
175 | /// @note |
176 | /// See the second note for `brgemm_kernel_execute` API. |
177 | /// |
178 | /// @param brg_kernel BRGEMM kernel |
179 | /// @param bs Specifies the size of batch |
180 | /// @param addr_A Pointer to first matrix A in the batch |
181 | /// @param addr_B Pointer to first matrix B in the batch |
182 | /// @param batch Array of batch elements containing offsets to matrices A,B |
183 | /// and virtual padding for matrix A |
184 | /// @param ptr_C Pointer to destination matrix C |
185 | /// @param scratch Scratchpad memory needed in several scenarios: |
186 | /// * Where: AMX+ hardware; When: always; For: buffer for tiles store. |
187 | /// * In rest scenarios is not used. |
188 | /// |
189 | void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, |
190 | const void *addr_A, const void *addr_B, |
191 | const brgemm_batch_element_t *batch, void *ptr_C, |
192 | void *scratch = nullptr); |
193 | |
194 | /// Execute BRGEMM kernel (brgemm_addr version) |
195 | /// |
196 | /// @note |
197 | /// BRGEMM kernel and post-operations will be executed |
198 | /// |
199 | /// @note |
200 | /// See the second note for `brgemm_kernel_execute` API. |
201 | /// |
202 | /// @param brg_kernel BRGEMM kernel |
203 | /// @param bs Specifies the size of batch |
204 | /// @param batch Array of batch elements containing pointers to matrices A,B |
205 | /// and virtual padding for matrices A |
206 | /// @param ptr_C Pointer to matrix C |
207 | /// @param ptr_D Pointer to destination matrix D |
208 | /// @param post_ops_data Specifies tensors and data used in post processing |
209 | /// phase |
210 | /// @param scratch Scratchpad memory needed in several scenarios: |
211 | /// * Where: AMX+ hardware; When: always; For: buffer for tiles store. |
212 | /// * Where: pre-VNNI hardware; When: s8s8 kernel; For: compensation buffer. |
213 | /// * In rest scenarios is not used. |
214 | /// |
215 | void DNNL_API brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, |
216 | int bs, const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, |
217 | const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr); |
218 | |
219 | /// Execute BRGEMM kernel (brgemm_offs and brgemm_strd version) |
220 | /// |
221 | /// @note |
222 | /// BRGEMM kernel and post-operations will be executed |
223 | /// |
224 | /// @note |
225 | /// See the second note for `brgemm_kernel_execute` API. |
226 | /// |
227 | /// @param brg_kernel BRGEMM kernel |
228 | /// @param bs Specifies the size of batch |
229 | /// @param addr_A Pointer to first matrix A in the batch |
230 | /// @param addr_B Pointer to first matrix B in the batch |
231 | /// @param batch Array of batch elements containing offsets to matrices A,B |
232 | /// and virtual padding for matrices A |
233 | /// @param ptr_C Pointer to destination matrix C |
234 | /// @param ptr_D Pointer to destination matrix D |
235 | /// @param post_ops_data Specifies tensors and data used in post processing |
236 | /// phase |
237 | /// @param scratch Scratchpad memory needed in several scenarios: |
238 | /// * Where: AMX+ hardware; When: always; For: buffer for tiles store. |
239 | /// * Where: pre-VNNI hardware; When: s8s8 kernel; For: compensation buffer. |
240 | /// * In rest scenarios is not used. |
241 | /// |
242 | void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, |
243 | const void *addr_A, const void *addr_B, |
244 | const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, |
245 | const brgemm_post_ops_data_t &post_ops_data, void *scratch = nullptr); |
246 | |
247 | /// AMX utilities: Creates a palette based on BRGEMM descriptor |
248 | /// |
249 | /// @note |
250 | /// This call expects brgemm_t object completely set up, thus, used after |
251 | /// `brgemm_desc_set_attr` call for non-empty attributes. |
252 | /// |
253 | /// @note |
254 | /// Caller is expected to subsequently configure AMX tiles by calling |
255 | /// amx_tile_configure(palette). |
256 | /// |
257 | /// @param brg BRGEMM descriptor |
258 | /// @param palette 64 bytes array contains tiles configuration |
259 | /// |
260 | status_t DNNL_API brgemm_init_tiles(const brgemm_t &brg, char palette[64]); |
261 | |
262 | } // namespace x64 |
263 | } // namespace cpu |
264 | } // namespace impl |
265 | } // namespace dnnl |
266 | |
267 | #endif |
268 | |
269 | //vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
270 | |