1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/jit_brgemm_inner_product_utils.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace cpu { |
22 | namespace x64 { |
23 | |
24 | using namespace dnnl::impl::status; |
25 | using namespace dnnl::impl::format_tag; |
26 | using namespace dnnl::impl::memory_tracking::names; |
27 | using namespace dnnl::impl::utils; |
28 | |
29 | using namespace prop_kind; |
30 | using namespace data_type; |
31 | |
32 | namespace brgemm_inner_product_utils { |
33 | |
34 | int get_brg_kernel_index(const jit_brgemm_primitive_conf_t &jbgp, |
35 | bool is_bs_tail, bool do_initialization, bool is_M_tail, bool is_N_tail, |
36 | bool is_K_tail) { |
37 | int idx = 16 * (int)is_bs_tail + 8 * (int)do_initialization |
38 | + 4 * (int)is_M_tail + 2 * (int)is_N_tail + (int)is_K_tail; |
39 | |
40 | assert(idx < max_num_brg_kernels_ip); |
41 | return idx; |
42 | } |
43 | |
// Returns the blocking size along the 'os' (mb x spatial) dimension.
// The choice depends on the propagation kind, ISA/data-type combination and
// several empirical shape thresholds.
// @param try_to_adjust use the forward-propagation heuristics regardless of
//        jbgp.prop_kind (used when probing alternative blockings).
// @param is_adjustment halve the maximal block to improve thread balance.
int get_os_block(const jit_brgemm_primitive_conf_t &jbgp, bool try_to_adjust,
        bool is_adjustment) {
    const bool is_amx_int8
            = jbgp.isa == avx512_core_amx && one_of(jbgp.wei_dt, s8, u8);
    const bool is_amx_xf16 = is_superset(jbgp.isa, avx512_core_amx)
            && one_of(jbgp.wei_dt, bf16, f16);
    const bool is_avx512_bf16 = jbgp.isa == avx512_core_bf16;
    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);
    const bool is_bf32 = jbgp.is_bf32;

    int max_os_block = 0;
    int min_os_block = 0;

    if (try_to_adjust
            || one_of(jbgp.prop_kind, forward_training, forward_inference)) {
        // On AMX a block must cover at least a full tile row (16).
        min_os_block = (is_amx_int8 || is_amx_xf16) ? 16 : 6;
        // Currently gigantic flag is used to separate out transformer_lt and
        // alexnet shapes for which larger os_block gives better performance.
        // TODO: Figure out how much the constraints for `gigantic-ness` can
        // be further loosened.
        const bool is_gigantic_shape
                = jbgp.ic >= 9216 && jbgp.oc >= 4096 && jbgp.os >= 512;
        const bool use_128_block_for_amx
                = is_amx_xf16 && jbgp.os % 128 == 0 && jbgp.oc > 128;
        const bool enable_128_os_blocking
                = use_128_block_for_amx || is_gigantic_shape;
        max_os_block = enable_128_os_blocking ? 128 : 64;
        // Work done by each thread is given by:
        // (nb_oc / nb_oc_blocking) * (nb_os / nb_os_blocking)
        // As a first approximation we take nb_oc_blocking = nb_os_blocking = 1
        // Furthermore, we recall that
        // nb_oc = oc / oc_block
        // nb_os = os / os_block
        //
        // For f32 data type our objective is to determine the optimal value
        // of os_block such that the work amount per thread ~ 2
        if (is_f32 && !is_bf32) {
            const bool small_work_amt_per_thread
                    = div_up(jbgp.os, max_os_block) * jbgp.nb_oc
                    < 1.8f * jbgp.nthr;
            if (small_work_amt_per_thread)
                max_os_block = saturate(16, max_os_block,
                        div_up(jbgp.os * jbgp.nb_oc, 2 * jbgp.nthr));
        }
    } else if (jbgp.prop_kind == backward_data) {
        // Platform-specific cap on the os block; larger blocks pay off only
        // for sufficiently large channel counts (empirical thresholds).
        int plat_max_os_block = 0;
        if (is_amx_xf16) {
            plat_max_os_block
                    = (jbgp.ic >= 512 && jbgp.oc / jbgp.ic <= 4) ? 128 : 64;
        } else if (is_avx512_bf16) {
            plat_max_os_block = (jbgp.ic > 256) ? 128 : 64;
        } else {
            plat_max_os_block = 64;
        }
        max_os_block = nstl::min(plat_max_os_block, jbgp.os);
        min_os_block = is_amx_xf16 ? 16 : 6;
    } else if (jbgp.prop_kind == backward_weights) {
        constexpr int amx_xf16_row = 64;
        constexpr int amx_xf16_half_row = amx_xf16_row / 2;
        // ensure that os_tail <= amx_xf16_half_row
        const bool use_large_os_block = (jbgp.os >= amx_xf16_row)
                && (jbgp.os % amx_xf16_row) <= amx_xf16_half_row;
        // Early return: bwd_w does not use the divisor search below.
        return is_amx_xf16
                ? (use_large_os_block ? amx_xf16_row : amx_xf16_half_row)
                : 16;
    } else
        assert(!"unsupported case" );

    if (is_adjustment) max_os_block /= 2;
    // Prefer the largest block in [min_os_block, max_os_block] that divides
    // os evenly; otherwise fall back to min(os, max_os_block).
    int os_block = 1;
    for (int osb = max_os_block; osb >= min_os_block; osb--) {
        // Guard for osb == 0 (min_os_block stays 0 on the assert path above).
        if (osb == 0) break;
        if (jbgp.os % osb == 0) {
            os_block = osb;
            break;
        }
    }
    if (os_block == 1) os_block = nstl::min(jbgp.os, max_os_block);

    return os_block;
}
125 | |
// Returns the candidate weights format tags for the given configuration,
// ordered by oc-block size: index 0 -> 64, index 1 -> 32, index 2 -> 16.
// The inner-block structure (plain, vnni-2i, vnni-4i) is selected from the
// weights data type and ISA; format_tag::undef is returned for unsupported
// data types.
std::vector<format_tag_t> get_desired_weights_tag(
        const jit_brgemm_primitive_conf_t &jbgp) {
    using namespace format_tag;
    const int n_sp_dims = jbgp.ndims - 2;
    const bool is_xf16 = utils::one_of(jbgp.wei_dt, bf16, f16);
    // f32 (and f16 on avx512_core_fp16) use plain, non-vnni interleaving.
    const bool is_not_vnni_tag = jbgp.wei_dt == f32
            || (jbgp.wei_dt == f16 && jbgp.isa == avx512_core_fp16);
    if (is_not_vnni_tag) {
        return {pick(n_sp_dims, OI16i64o, OIw16i64o, OIhw16i64o, OIdhw16i64o),
                pick(n_sp_dims, OI16i32o, OIw16i32o, OIhw16i32o, OIdhw16i32o),
                pick(n_sp_dims, OI16i16o, OIw16i16o, OIhw16i16o, OIdhw16i16o)};
    } else if (is_xf16) {
        // xf16: vnni pairs (2i); AMX uses 16i granularity, otherwise 8i.
        if (is_superset(jbgp.isa, avx512_core_amx)) {
            return {pick(n_sp_dims, OI16i64o2i, OIw16i64o2i, OIhw16i64o2i,
                            OIdhw16i64o2i),
                    pick(n_sp_dims, OI16i32o2i, OIw16i32o2i, OIhw16i32o2i,
                            OIdhw16i32o2i),
                    pick(n_sp_dims, OI16i16o2i, OIw16i16o2i, OIhw16i16o2i,
                            OIdhw16i16o2i)};
        } else {
            return {pick(n_sp_dims, OI8i64o2i, OIw8i64o2i, OIhw8i64o2i,
                            OIdhw8i64o2i),
                    pick(n_sp_dims, OI8i32o2i, OIw8i32o2i, OIhw8i32o2i,
                            OIdhw8i32o2i),
                    pick(n_sp_dims, OI8i16o2i, OIw8i16o2i, OIhw8i16o2i,
                            OIdhw8i16o2i)};
        }
    } else if (jbgp.wei_dt == data_type::s8) {
        // int8: vnni quadruples (4i); AMX uses 16i granularity, else 4i.
        if (jbgp.isa == avx512_core_amx) {
            return {pick(n_sp_dims, OI16i64o4i, OIw16i64o4i, OIhw16i64o4i,
                            OIdhw16i64o4i),
                    pick(n_sp_dims, OI16i32o4i, OIw16i32o4i, OIhw16i32o4i,
                            OIdhw16i32o4i),
                    pick(n_sp_dims, OI16i16o4i, OIw16i16o4i, OIhw16i16o4i,
                            OIdhw16i16o4i)};
        } else {
            return {pick(n_sp_dims, OI4i64o4i, OIw4i64o4i, OIhw4i64o4i,
                            OIdhw4i64o4i),
                    pick(n_sp_dims, OI4i32o4i, OIw4i32o4i, OIhw4i32o4i,
                            OIdhw4i32o4i),
                    pick(n_sp_dims, OI4i16o4i, OIw4i16o4i, OIhw4i16o4i,
                            OIdhw4i16o4i)};
        }
    } else {
        // Unsupported weights data type.
        return std::vector<format_tag_t> {format_tag::undef};
    }
}
173 | |
174 | int get_oc_block(const jit_brgemm_primitive_conf_t &jbgp, bool try_to_adjust) { |
175 | const bool amx_xf16_bwd_d_noadjust = !try_to_adjust |
176 | && jbgp.prop_kind == backward_data |
177 | && is_superset(jbgp.isa, avx512_core_amx) && !jbgp.is_bf32; |
178 | if (amx_xf16_bwd_d_noadjust) { |
179 | constexpr int amx_xf16_row = 64; |
180 | return amx_xf16_row; |
181 | } else if (!jbgp.is_wei_layout_any) { |
182 | std::vector<format_tag_t> weights_tag = get_desired_weights_tag(jbgp); |
183 | if (jbgp.wei_tag == weights_tag[0]) |
184 | return 64; |
185 | else if (jbgp.wei_tag == weights_tag[1]) |
186 | return 32; |
187 | else |
188 | return 16; |
189 | } else { |
190 | if (jbgp.oc >= 64) { |
191 | return 64; |
192 | } else if (jbgp.oc >= 32) { |
193 | return 32; |
194 | } else { |
195 | return 16; |
196 | } |
197 | } |
198 | } |
199 | |
200 | int ip_fwd_get_nb_oc_blocking( |
201 | const jit_brgemm_primitive_conf_t &jbgp, bool is_adjustment) { |
202 | const int small_oc_threshold = 256; |
203 | const int small_os_threshold = 8; |
204 | if (jbgp.os <= small_os_threshold && jbgp.oc <= small_oc_threshold) { |
205 | // For small problems compute all oc blocks as a single chunk to avoid |
206 | // parallel section |
207 | return div_up(jbgp.oc, |
208 | (is_adjustment) ? ip_fwd_get_adjusted_oc_block(jbgp) |
209 | : get_oc_block(jbgp)); |
210 | } else |
211 | return 1; |
212 | } |
213 | |
214 | bool ip_fwd_adjust_thread_balance(const jit_brgemm_primitive_conf_t &jbgp) { |
215 | if (IMPLICATION(jbgp.is_wei_layout_any, |
216 | !is_superset(jbgp.isa, avx512_core_amx))) |
217 | return false; |
218 | |
219 | int os_chunks = div_up(jbgp.os, get_os_block(jbgp, true, false)); |
220 | |
221 | int nb_oc = div_up(jbgp.oc, get_oc_block(jbgp, true)); |
222 | int nb_oc_blocking = ip_fwd_get_nb_oc_blocking(jbgp); |
223 | int oc_chunks = div_up(nb_oc, nb_oc_blocking); |
224 | |
225 | int work_amount = oc_chunks * os_chunks; |
226 | const auto work_per_thread = work_amount / jbgp.nthr; |
227 | float wb_ratio = static_cast<float>(work_amount % jbgp.nthr) / jbgp.nthr; |
228 | |
229 | // return true if work distribution between threads has significant |
230 | // imbalance - amount of work per thread is small and the last iteration |
231 | // is able to load less than half of threads available for the current |
232 | // block sizes |
233 | return (work_per_thread < 3 && wb_ratio > 0.0f && wb_ratio < .5f); |
234 | } |
235 | |
236 | int ip_fwd_get_adjusted_oc_block(const jit_brgemm_primitive_conf_t &jbgp) { |
237 | const bool is_amx_xf16 |
238 | = is_superset(jbgp.isa, avx512_core_amx) && !jbgp.is_bf32; |
239 | |
240 | // we can't change block size on forward and weights update (external) |
241 | // if layout is set by user, for backward data it can be choosen different |
242 | // from external in this case because copy routine |
243 | const bool not_adjustable_oc_block_size |
244 | = !jbgp.is_wei_layout_any && jbgp.prop_kind != backward_data; |
245 | |
246 | if (IMPLICATION(is_amx_xf16 || jbgp.is_bf32, not_adjustable_oc_block_size)) |
247 | return get_oc_block(jbgp); |
248 | |
249 | int oc_block = get_oc_block(jbgp, true); |
250 | if (ip_fwd_adjust_thread_balance(jbgp)) { |
251 | oc_block = (oc_block > 16) ? oc_block / 2 : oc_block; |
252 | } |
253 | |
254 | constexpr int amx_bf16_half_row = 32; |
255 | // ensure that oc_tail <= amx_bf16_half_row (requirement for brgemm kernel) |
256 | while (jbgp.oc % oc_block > amx_bf16_half_row) |
257 | oc_block /= 2; |
258 | return oc_block; |
259 | } |
260 | |
261 | format_tag_t get_brgemm_ip_weights_tag(cpu_isa_t isa, |
262 | const jit_brgemm_primitive_conf_t &jbgp, |
263 | const memory_desc_t &weights_md) { |
264 | std::vector<format_tag_t> weights_tag = get_desired_weights_tag(jbgp); |
265 | if (!jbgp.is_wei_layout_any) { |
266 | return memory_desc_matches_one_of_tag( |
267 | weights_md, weights_tag[0], weights_tag[1], weights_tag[2]); |
268 | } else { |
269 | const int oc_block = ip_fwd_get_adjusted_oc_block(jbgp); |
270 | const int idx = (oc_block == 64 ? 0 : (oc_block == 32 ? 1 : 2)); |
271 | return weights_tag[idx]; |
272 | } |
273 | } |
274 | |
275 | bool post_ops_ok(jit_brgemm_primitive_conf_t &jbgp, |
276 | const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) { |
277 | using namespace injector; |
278 | |
279 | const auto &post_ops = attr.post_ops_; |
280 | |
281 | return injector::post_ops_ok(post_ops_ok_args_t(get_max_cpu_isa(), |
282 | {sum, eltwise, binary}, post_ops, &dst_d, |
283 | false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/, |
284 | true /*sum_requires_zp_zero*/, |
285 | {broadcasting_strategy_t::per_oc, broadcasting_strategy_t::scalar, |
286 | broadcasting_strategy_t::no_broadcast})); |
287 | } |
288 | |
// Initializes the forward-propagation blocking configuration: post-ops and
// scales validation, ic/oc/os blocking, K/batch decomposition (with optional
// parallel IC reduction for f32), and the brgemm matrix dimensions and
// leading dimensions. Returns unimplemented for unsupported combinations.
status_t init_ip_conf_fwd(jit_brgemm_primitive_conf_t &jbgp,
        const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) {
    const bool is_amx_int8
            = jbgp.isa == avx512_core_amx && one_of(jbgp.wei_dt, s8, u8);
    const bool is_amx_xf16 = is_superset(jbgp.isa, avx512_core_amx)
            && one_of(jbgp.wei_dt, bf16, f16) && !jbgp.is_bf32;
    const bool is_int8 = one_of(jbgp.src_dt, u8, s8) && jbgp.wei_dt == s8;
    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);
    jbgp.is_amx = is_superset(jbgp.isa, avx512_core_amx);
    // Post-ops flags derived from the attribute's post-ops chain.
    const auto &p = attr.post_ops_;
    jbgp.with_sum = p.find(primitive_kind::sum) != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jbgp.with_eltwise = eltwise_ind != -1;
    const int binary_ind = p.find(primitive_kind::binary);
    jbgp.with_binary = binary_ind != -1;
    if (!post_ops_ok(jbgp, attr, dst_d)) return status::unimplemented;
    if (jbgp.with_scales) {
        const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
        jbgp.is_oc_scale = wei_scales.mask_ == 1 << 0;

        // only common and per-oc-channel scales are supported
        const bool wei_scales_ok = one_of(wei_scales.mask_, 0, 1 << 0);
        if (!wei_scales_ok) return status::unimplemented;
    }
    // AMX vnni granularity: ic must be a multiple of 4 (int8) or 2 (xf16),
    // otherwise a copy buffer for A is required.
    const int min_ic_divisor = is_amx_int8 ? 4 : is_amx_xf16 ? 2 : 1;

    jbgp.use_buffer_a = jbgp.ic % min_ic_divisor != 0;

    constexpr int amx_int8_row = 64;
    constexpr int amx_xf16_row = 32;
    jbgp.ic_block = (is_amx_int8) ? amx_int8_row
                                  : (is_amx_xf16) ? amx_xf16_row : jbgp.simd_w;
    jbgp.nb_ic = div_up(jbgp.ic, jbgp.ic_block);

    // gemm-based inner product performs better when oc = 1
    if (is_f32 && !jbgp.is_bf32 && jbgp.oc == 1) return status::unimplemented;

    jbgp.oc_block = ip_fwd_get_adjusted_oc_block(jbgp);
    jbgp.nb_oc = div_up(jbgp.oc, jbgp.oc_block);
    jbgp.nb_oc_blocking = ip_fwd_get_nb_oc_blocking(jbgp);

    jbgp.os_block = get_os_block(jbgp, false, false);
    jbgp.nb_os = div_up(jbgp.os, jbgp.os_block);

    jbgp.nb_os_blocking = 1;
    // Work done by each thread is given by:
    // (nb_oc / nb_oc_blocking) * (nb_os / nb_os_blocking)
    // For f32 data type we want to increase the nb_os_blocking such that
    //   * 1 <= nb_os_blocking <= 8 AND nb_os_blocking <= nb_os
    //   * Work amount per thread ~ 2
    //   * NOTE: here nb_oc_blocking = 1 as os is large
    if (jbgp.os > 256 && is_f32 && !jbgp.is_bf32) {
        jbgp.nb_os_blocking = saturate(1, nstl::min(8, jbgp.nb_os),
                nstl::min(nstl::max(jbgp.oc / jbgp.os / 2, 1),
                        div_up(jbgp.nb_os * jbgp.nb_oc, 2 * jbgp.nthr)));
    }

    // NOTE: comment about is_gigantic_shape is in get_os_block()
    const bool is_gigantic_shape = jbgp.oc >= 4096 && jbgp.os >= 512;
    const int num_work_to_parallel = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking)
            * div_up(jbgp.nb_os, jbgp.nb_os_blocking);

    // TODO: although the below heuristic produces good performance for fp32,
    // num_work_to_parallel needs to compared with nthr (instead of nb_ic)
    // and os_block needs some further tuning.

    // Use parallel IC reduction for f32 if we have:
    //  * very large input channels
    //  * work amount in mb and oc dimensions is small compared to nb_ic
    //  * number of threads > 1
    //  * not a "gigantic shape" since it already has a lot of parallelism
    //    in mb and oc dimensions w/o enabling IC parallelism
    const bool use_parallel_ic_reduction = is_f32 && !jbgp.is_bf32
            && jbgp.ic > 1024 && num_work_to_parallel < jbgp.nb_ic
            && jbgp.nthr > 1 && !is_gigantic_shape;

    // For os > 256, compute all os blocks as a single chunk when performing
    // IC reduction. Note that this condition is empirical
    if (use_parallel_ic_reduction && jbgp.os > 256 && jbgp.nb_os_blocking > 1)
        jbgp.nb_os_blocking = jbgp.nb_os;

    jbgp.nb_ic_blocking = 1;
    jbgp.nthr_ic_b = 1;
    // K granularity: bf32 uses the AMX xf16 row, otherwise the ic block.
    const int k_blk = jbgp.is_bf32 ? amx_xf16_row : jbgp.ic_block;
    const int max_nb_ic_blocking = nstl::min(64, jbgp.nb_ic);
    if (IMPLICATION(!is_int8, jbgp.ic <= max_nb_ic_blocking * jbgp.ic_block)
            && everyone_is(1, jbgp.kw, jbgp.kh, jbgp.kd)
            && !jbgp.use_buffer_a) {
        // Optimization: data & weights layouts allow to generate
        // brgemm kernel with K = ic & batch = 1
        // (K = rnd_dn(ic, ic_block), K_tail = ic % ic_block & batch = 1)
        // instead of K = ic_block & batch = nb_ic_blocking
        jbgp.K = jbgp.ic <= jbgp.ic_block ? jbgp.ic : rnd_dn(jbgp.ic, k_blk);
        jbgp.nb_ic_blocking = jbgp.nb_ic;
        jbgp.gemm_batch_size = 1;
    } else if (!jbgp.use_buffer_a && use_parallel_ic_reduction) {
        // Parallel IC reduction: split nb_ic across nthr_ic_b thread groups.
        // 'reduce_work' is an empirical estimate used to cap the group count.
        const int min_chunk_sz = 16;
        const int num_min_chunk_sz = div_up(jbgp.nb_ic, min_chunk_sz);
        float reduce_work = 0.5f * num_min_chunk_sz * jbgp.nb_os
                + (float)num_min_chunk_sz / jbgp.nb_oc + 0.5f;
        const int reduce_thr_groups = jbgp.nb_ic >= 1024 ? 8 : 4;
        jbgp.nthr_ic_b
                = saturate(1, nstl::min(reduce_thr_groups, num_min_chunk_sz),
                        int(reduce_work));
        jbgp.nthr_ic_b = nstl::min(jbgp.nthr_ic_b, jbgp.nthr);
        if (jbgp.nthr_ic_b > 1) {
            jbgp.nb_ic_blocking = div_up(jbgp.nb_ic, jbgp.nthr_ic_b);
            jbgp.nb_ic_blocking /= div_up(jbgp.nb_ic_blocking, 64);
        }
        jbgp.gemm_batch_size = jbgp.nb_ic_blocking;
        jbgp.K = jbgp.ic_block;
    } else {
        // Note: Here, ic divided into K_blocks of gemm_batch
        const int ic_blks_per_k = div_up(k_blk, jbgp.ic_block);
        const int nb_k_blk = div_up(jbgp.ic, k_blk);
        const int max_nb_k_blocking = div_up(max_nb_ic_blocking, ic_blks_per_k);
        int nb_k_blocking = max_div(nb_k_blk, max_nb_k_blocking);
        const bool small_nb_k_blk = nb_k_blk <= max_nb_k_blocking;
        if (small_nb_k_blk && nb_k_blocking == 1)
            nb_k_blocking = max_nb_k_blocking;

        // For non small_nb_ic [i.e. that has nb_ic > 64] shape that has
        // gcd(nb_ic, 64) < 16, we manually set nb_ic_blocking = 64
        // the coefficients 64 [used in max_nb_ic_blocking] and 16 are empirical
        const int min_nb_k_blocking = small_nb_k_blk ? 1 : 16;
        if (nb_k_blocking < min_nb_k_blocking)
            nb_k_blocking = max_nb_k_blocking;

        jbgp.nb_ic_blocking = nb_k_blocking * ic_blks_per_k;
        jbgp.K = k_blk;
        jbgp.gemm_batch_size = nb_k_blocking;
    }

    // to avoid cache concurrent write access from different threads
    size_t sc_size = sizeof(brgemm_batch_element_t);
    jbgp.adjusted_batch_size
            = div_up(rnd_up(jbgp.gemm_batch_size * sc_size, 4096), sc_size);

    if (is_amx_xf16 || jbgp.is_bf32) {
        if (ip_fwd_adjust_thread_balance(jbgp)) {
            // Adjust oc_block to improve thread balancing
            jbgp.oc_block = ip_fwd_get_adjusted_oc_block(jbgp);
            jbgp.nb_oc = div_up(jbgp.oc, jbgp.oc_block);
            jbgp.nb_oc_blocking = ip_fwd_get_nb_oc_blocking(jbgp, true);

            // Adjust os_block to improve thread balancing
            if (jbgp.oc <= 16
                    || types::data_type_size(jbgp.src_dt) * jbgp.mb * jbgp.ic
                            <= (size_t)platform::get_per_core_cache_size(2)) {
                jbgp.os_block = get_os_block(jbgp, false, true);
                jbgp.nb_os = div_up(jbgp.os, jbgp.os_block);
            }
        }
    }
    // An intermediate accumulation buffer is needed when the destination
    // type differs from the accumulator type with a sum post-op, or when
    // several thread groups reduce over ic.
    jbgp.use_buffer = (IMPLICATION(jbgp.dst_dt == jbgp.acc_dt, jbgp.with_sum))
            || (jbgp.nthr_ic_b > 1);

    // Configure matrix sizes
    jbgp.M = jbgp.os_block;
    jbgp.M_tail = jbgp.os % jbgp.os_block;

    jbgp.N = jbgp.oc_block;
    jbgp.N_tail = jbgp.oc % jbgp.oc_block;
    jbgp.K_tail = jbgp.use_buffer_a ? 0 : jbgp.ic % jbgp.K;

    jbgp.LDA = jbgp.use_buffer_a ? jbgp.K * jbgp.gemm_batch_size
                                 : jbgp.ic_without_padding;
    jbgp.LDB = jbgp.N;
    jbgp.LDD = jbgp.oc_without_padding;
    jbgp.LDC = (jbgp.use_buffer && jbgp.nthr_ic_b == 1) ? jbgp.N : jbgp.LDD;

    if (jbgp.is_bf32) {
        // Reject bf32 configurations whose tiles are too small to use the
        // tmul unit efficiently (empirical efficiency model).
        const float M = static_cast<float>(jbgp.M);
        const float N = nstl::min<float>(jbgp.N, jbgp.oc);
        const float K
                = nstl::min<float>(jbgp.K * jbgp.gemm_batch_size, jbgp.ic);
        const float tmul_efficiency = (M / 16) * (N / 16) * (K / 32);
        // TODO: Adjust blocking such that bigger M, N, K are generated.
        if (one_of(true, M <= 8, K <= 8, N < 16, tmul_efficiency <= 2.25))
            return status::unimplemented;
    }

    return status::success;
}
473 | |
// Initializes the backward-by-data blocking configuration: ic/oc/os blocking,
// optional oc-reduction across thread groups, and the brgemm matrix
// dimensions and leading dimensions (here the gemm computes diff_src from
// diff_dst and weights, so K runs over oc and N over ic).
status_t init_ip_conf_bwd_d(jit_brgemm_primitive_conf_t &jbgp) {
    const bool is_amx_xf16
            = is_superset(jbgp.isa, avx512_core_amx) && !jbgp.is_bf32;
    const bool is_avx512_bf16 = jbgp.isa == avx512_core_bf16;
    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);
    const bool is_bf16 = everyone_is(bf16, jbgp.wei_dt, jbgp.dst_dt);
    jbgp.is_amx = is_amx_xf16;

    // AMX xf16 requires oc to be even; otherwise copy A into a buffer.
    constexpr int amx_xf16_granularity = 2;
    jbgp.use_buffer_a = is_amx_xf16 && jbgp.oc % amx_xf16_granularity != 0;
    // Weights are always transposed via a copy routine on bwd_d.
    jbgp.use_buffer_b = true;
    jbgp.ip_bwd_d_global_b_transpose = false;

    jbgp.oc_block = ip_fwd_get_adjusted_oc_block(jbgp);

    // Optimization: for small shape we avoid large ic_block
    // Thinking of os, ic, and oc as three dimensions, the boundary for small
    // shapes is heuristically chosen via the following constraints:
    // os <= 128 && max(ic, oc) <= 2048 && min(ic, oc) <= 1000
    //
    // TODO: Will the optimization be useful for bf16 data type
    const bool avoid_max_ic_block = is_f32 && !jbgp.is_bf32 && jbgp.os <= 128
            && nstl::max(jbgp.ic, jbgp.oc) <= 2048
            && nstl::min(jbgp.ic, jbgp.oc) <= 1000;
    // ic_block in {64, 32, 16}; the f32 (non-bf32) path needs a larger ic
    // (>= 512) to justify the 64 block.
    jbgp.ic_block = !avoid_max_ic_block
                    && jbgp.ic >= (is_f32 && !jbgp.is_bf32 ? 512 : 64)
            ? 64
            : jbgp.ic >= 32 ? 32 : 16;

    jbgp.nb_ic = div_up(jbgp.ic, jbgp.ic_block);
    jbgp.nb_ic_blocking = 1;
    jbgp.nb_oc = div_up(jbgp.oc, jbgp.oc_block);

    jbgp.os_block = get_os_block(jbgp, false, false);

    jbgp.nb_os = div_up(jbgp.os, jbgp.os_block);
    jbgp.nb_os_blocking = 1;
    // Pick the largest divisor of nb_os not exceeding 2 as the os chunk.
    int os_blocking_max = 2;
    for (int bl = os_blocking_max; bl >= 1; bl--)
        if (jbgp.nb_os % bl == 0) {
            jbgp.nb_os_blocking = bl;
            break;
        }

    if (is_amx_xf16 || jbgp.is_bf32) {
        // Halve ic_block when the work amount is slightly above one round
        // of threads (1 < ratio < 2) to reduce the imbalance.
        const int os_chunks = div_up(jbgp.nb_os, jbgp.nb_os_blocking);
        const int work_amount = jbgp.nb_ic * os_chunks;
        float wb_ratio = (float)work_amount / (float)jbgp.nthr;
        if (wb_ratio != 1.f && wb_ratio < 2.f) {
            jbgp.ic_block
                    = (jbgp.ic_block > 16) ? jbgp.ic_block / 2 : jbgp.ic_block;
            jbgp.nb_ic = div_up(jbgp.ic, jbgp.ic_block);
        }
    }

    jbgp.nb_oc_blocking = 1;
    // Pick the largest divisor of nb_oc not exceeding 64 as the oc chunk.
    const int oc_chunk_max_size = 64;
    for (int bl = oc_chunk_max_size; bl >= 1; bl--)
        if (jbgp.nb_oc % bl == 0) {
            jbgp.nb_oc_blocking = bl;
            break;
        }

    jbgp.nthr_oc_b = 1;
    const int num_work_to_parallel = div_up(jbgp.nb_ic, jbgp.nb_ic_blocking)
            * div_up(jbgp.nb_os, jbgp.nb_os_blocking);
    // Use oc reduction if we have
    //  * very large output channels
    //  * small work amount available to each thread
    if ((num_work_to_parallel < 2 * jbgp.nthr
                || jbgp.oc > (is_bf16 || jbgp.is_bf32 ? 4096 : 1024))) {
        const int min_chunck_sz = (is_avx512_bf16) ? 32 : 16;
        const int num_min_chunk_sz = div_up(jbgp.nb_oc, min_chunck_sz);
        // Empirical estimate to cap the number of oc-reduction groups.
        float reduce_work = 0.5f * num_min_chunk_sz * jbgp.nb_os
                + (float)num_min_chunk_sz / jbgp.nb_ic + 0.5f;

        // optimization for transformer_lt on CPX/SKX
        const int max_nthr_oc_b
                = (!is_amx_xf16 && !jbgp.is_bf32 && jbgp.oc > 32000)
                ? jbgp.nthr / 2
                : 4;
        jbgp.nthr_oc_b = saturate(1, nstl::min(max_nthr_oc_b, num_min_chunk_sz),
                int(reduce_work));
        jbgp.nthr_oc_b = nstl::min(jbgp.nthr_oc_b, jbgp.nthr);
        if (jbgp.nthr_oc_b > 1) {
            jbgp.nb_oc_blocking = div_up(jbgp.nb_oc, jbgp.nthr_oc_b);
            jbgp.nb_oc_blocking
                    /= div_up(jbgp.nb_oc_blocking, oc_chunk_max_size);
        }
    }
    jbgp.gemm_batch_size = jbgp.nb_oc_blocking;
    // to avoid cache concurrent write access from different threads
    size_t sc_size = sizeof(brgemm_batch_element_t);
    jbgp.adjusted_batch_size
            = div_up(rnd_up(jbgp.gemm_batch_size * sc_size, 4096), sc_size);

    // Accumulation buffer needed when diff_src type differs from the
    // accumulator type or several thread groups reduce over oc.
    jbgp.use_buffer = jbgp.src_dt != jbgp.acc_dt || jbgp.nthr_oc_b > 1;

    jbgp.M = jbgp.os_block;
    jbgp.M_tail = jbgp.os % jbgp.os_block;

    jbgp.K = jbgp.oc_block;
    jbgp.N = jbgp.ic_block;
    jbgp.N_tail = jbgp.ic % jbgp.ic_block;
    jbgp.K_tail = jbgp.use_buffer_a ? 0 : jbgp.oc % jbgp.oc_block;

    jbgp.LDA = jbgp.use_buffer_a ? jbgp.K * jbgp.nb_oc_blocking
                                 : jbgp.oc_without_padding;
    jbgp.LDB = jbgp.N;
    jbgp.LDD = jbgp.ic_without_padding;
    jbgp.LDC = jbgp.use_buffer && jbgp.nthr_oc_b == 1 ? jbgp.N : jbgp.LDD;

    if (jbgp.is_bf32) {
        // Reject bf32 configurations whose tiles are too small to use the
        // tmul unit efficiently (empirical efficiency model).
        const float M = static_cast<float>(jbgp.M);
        const float N = nstl::min<float>(jbgp.N, jbgp.ic);
        const float K
                = nstl::min<float>(jbgp.K * jbgp.gemm_batch_size, jbgp.oc);
        const float tmul_efficiency = (M / 16) * (N / 16) * (K / 32);
        // TODO: Adjust blocking such that bigger M, N, K are generated.
        if (one_of(true, M <= 8, K <= 8, N < 16, tmul_efficiency <= 2.25))
            return status::unimplemented;
    }

    return status::success;
}
599 | |
// Searches for the thread decomposition (nthr_mb x nthr_oc_b x nthr_ic_b)
// and the nb_os/nb_oc/nb_ic blocking (used for backward-by-weights) that
// minimize an estimated memory-traffic cost. All results are returned via
// the output reference parameters; nthr_ is the product of the three
// per-dimension thread counts.
void thread_balance(const jit_brgemm_primitive_conf_t &j, int &nb_os_blocking_,
        int &nb_oc_blocking_, int &nb_ic_blocking_, int &nthr_, int &nthr_mb_,
        int &nthr_oc_b_, int &nthr_ic_b_) {
    nthr_ = nthr_mb_ = nthr_oc_b_ = nthr_ic_b_ = 1;
    nb_os_blocking_ = j.nb_os_blocking;
    nb_oc_blocking_ = j.nb_oc_blocking;
    nb_ic_blocking_ = j.nb_ic_blocking;

    const bool is_f32 = everyone_is(f32, j.src_dt, j.wei_dt, j.dst_dt);
    const bool is_xf16 = one_of(j.src_dt, bf16, f16) && (j.src_dt == j.dst_dt);

    const int max_threads = j.nthr;
    const int nthr = max_threads;
    // Heuristic model of the per-thread memory traffic for a candidate
    // decomposition; smaller is better. The terms cover transposed-copy
    // traffic (src_tr/dst_tr), tensor reads (src_v/dst_v), accumulator
    // writes (wei_v), and the reduction over nthr_mb copies (wei_r).
    auto calc_mem_cost = [=](int nb_os_blocking, int nb_oc_blocking,
                                 int nb_ic_blocking, int nthr_mb, int nthr_oc,
                                 int nthr_ic) {
        float src_size = static_cast<float>(j.ic) * j.mb;
        float dst_size = static_cast<float>(j.oc) * j.mb;
        float wei_size = static_cast<float>(j.ic) * j.oc;
        int os_chunks = div_up(j.nb_os, nb_os_blocking);
        int oc_chunks = div_up(j.nb_oc, nb_oc_blocking);
        int ic_chunks = div_up(j.nb_ic, nb_ic_blocking);

        // How large the weights are relative to the activations; used to
        // weight the cost terms against each other.
        float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;

        float oi_channels_ratio = 0;
        if (is_xf16) {
            oi_channels_ratio = ((j.oc > 3 * j.ic && os_chunks > 1)
                                        || (os_chunks == 1 && j.ic > j.oc))
                    ? src_size / dst_size
                    : dst_size / src_size;
        } else {
            oi_channels_ratio = src_size / dst_size;
        }

        // Empirical per-tensor weighting coefficients (f32 vs xf16 paths).
        auto get_src_coef = [=]() {
            if (is_f32) {
                float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
                src_coef *= types::data_type_size(j.src_dt);
                src_coef *= 4 * saturate(1, 4, div_up(j.ic, 1024));
                if (wei_compensation_scale < 2.0f)
                    src_coef += sqrtf(2.0f / wei_compensation_scale);
                return src_coef;
            }
            float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
            src_coef *= 4 * types::data_type_size(j.src_dt);
            if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;

            return src_coef;
        };

        auto get_dst_coef = [=]() {
            if (is_f32) {
                float dst_coef = types::data_type_size(j.dst_dt)
                        * nstl::max(oi_channels_ratio, 1.0f);
                return dst_coef;
            }

            return 2 * types::data_type_size(j.dst_dt)
                    * nstl::max(oi_channels_ratio, 1.0f);
        };

        auto get_wei_coef = [=]() {
            if (is_f32) {
                return nstl::max(
                        4.0f - j.mb / 2048 * wei_compensation_scale, 1.0f);
            }

            // limit the range of coefficient values to have more stable behavior
            // for extreme cases
            const float low_limit = 1.0f;
            const float upper_limit = 1024.0f;
            return utils::saturate(
                    low_limit, upper_limit, wei_compensation_scale);
        };

        // Cost of transposing src into buffer A (non-f32 only).
        float src_tr = 0.0f;
        if (j.use_buffer_a && !is_f32) {
            int src_tr_oc_par_work = div_up(os_chunks, nthr_mb)
                    * div_up(ic_chunks, nthr_ic) * nb_ic_blocking;
            src_tr = get_src_coef() * div_up(src_tr_oc_par_work, nthr_oc)
                    * nb_os_blocking * j.os_block * j.ic_block;
        }

        // Cost of transposing diff_dst into buffer B (non-f32 only).
        float dst_tr = 0.0f;
        if (j.use_buffer_b && !is_f32) {
            int dst_tr_ic_par_work = div_up(os_chunks, nthr_mb)
                    * div_up(oc_chunks, nthr_oc) * nb_oc_blocking;
            dst_tr = get_dst_coef() * div_up(dst_tr_ic_par_work, nthr_ic)
                    * nb_os_blocking * j.os_block * j.oc_block;
        }

        // Per-thread volume of src and dst reads.
        float src_v = get_src_coef() * div_up(os_chunks, nthr_mb)
                * div_up(ic_chunks, nthr_ic) * nb_os_blocking * j.os_block
                * nb_ic_blocking * j.ic_block;
        float dst_v = get_dst_coef() * div_up(os_chunks, nthr_mb)
                * div_up(oc_chunks, nthr_oc) * nb_os_blocking * j.os_block
                * nb_oc_blocking * j.oc_block;

        // Per-thread volume of accumulator (diff_wei) writes.
        auto acc_dt_sz = types::data_type_size(j.acc_dt);
        float wei_v = get_wei_coef() * acc_dt_sz * div_up(oc_chunks, nthr_oc)
                * div_up(ic_chunks, nthr_ic) * nb_oc_blocking * j.oc_block
                * nb_ic_blocking * j.ic_block;

        // Extra traffic for reducing nthr_mb partial weight copies.
        float wei_r = 0;
        if (nthr_mb > 1) {
            auto wei_dt_sz = types::data_type_size(j.wei_dt);
            int wei_r_mb_par_work = div_up(oc_chunks, nthr_oc)
                    * div_up(ic_chunks, nthr_ic) * nb_oc_blocking
                    * nb_ic_blocking;
            wei_r = get_wei_coef() * div_up(wei_r_mb_par_work, nthr_mb)
                    * j.oc_block * j.ic_block
                    * (wei_dt_sz
                            + (is_f32 ? div_up(j.os, 1024) : 1) * nthr_mb
                                    * acc_dt_sz);
        }

        return src_tr + dst_tr + src_v + dst_v + wei_v + wei_r;
    };

    float best_mem_cost = calc_mem_cost(nb_os_blocking_, nb_oc_blocking_,
            nb_ic_blocking_, nthr_mb_, nthr_oc_b_, nthr_ic_b_);

    /* Set range of values for nb_oc_blocking/nb_ic_blocking parameters to try.
       Use powers-of-2 values to avoid potential issues on converting to
       blocked weights layout stage
    */
    auto get_blk_values
            = [](int max_blk_value, int init_blk, int dim_blk_limit) {
                  int val_1st = rnd_up_pow2(init_blk);
                  int val_end = nstl::min(max_blk_value, dim_blk_limit);
                  std::vector<int> values;
                  for (int val = val_1st; val <= val_end; val <<= 1)
                      values.push_back(val);
                  return values;
              };

    const int max_nb_oc_blocking_pow
            = j.ip_bwd_w_local_buffers_for_input_tensors ? 4 : j.nb_oc_blocking;
    auto nb_oc_blocking_values
            = get_blk_values(max_nb_oc_blocking_pow, j.nb_oc_blocking, j.nb_oc);
    const int max_nb_ic_blocking_pow
            = j.ip_bwd_w_local_buffers_for_input_tensors ? 4 : j.nb_ic_blocking;
    auto nb_ic_blocking_values
            = get_blk_values(max_nb_ic_blocking_pow, j.nb_ic_blocking, j.nb_ic);

    /* find the best thread distribution with lowest memory cost */
    // Cap nthr_mb so each mb-group keeps at least min_osb_chunk os blocks.
    const int min_osb_chunk = is_f32 ? 32 : is_xf16 ? 8 : 1;
    const int nthr_mb_max = nstl::min(nthr, div_up(j.nb_os, min_osb_chunk));
    for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
        int nb_os_blocking = j.nb_os_blocking;
        int os_chunks = div_up(j.nb_os, nb_os_blocking);
        if (os_chunks < nthr_mb) {
            // Not enough os chunks for this nthr_mb: shrink nb_os_blocking
            // to the largest divisor of nb_os below the adjusted maximum.
            int coef = saturate(1, 4, 2 * j.mb / (j.oc + j.ic));
            int os_blocking_max = div_up(div_up(j.nb_os, coef), nthr_mb);
            for (int bl = os_blocking_max; bl >= 1; bl--)
                if (j.nb_os % bl == 0) {
                    nb_os_blocking = bl;
                    break;
                }
        }

        // Threads left to split between the oc and ic dimensions.
        const int nthr_par = nthr / nthr_mb;

        for (auto nb_oc_blocking : nb_oc_blocking_values) {
            int num_oc_chunks = div_up(j.nb_oc, nb_oc_blocking);
            const int nthr_oc_b_max = nstl::min(nthr_par, num_oc_chunks);
            for_(int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b)
            for (auto nb_ic_blocking : nb_ic_blocking_values) {
                int num_ic_chunks = div_up(j.nb_ic, nb_ic_blocking);

                int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, num_ic_chunks);
                float mem_cost = calc_mem_cost(nb_os_blocking, nb_oc_blocking,
                        nb_ic_blocking, nthr_mb, nthr_oc_b, nthr_ic_b);
                if (mem_cost <= best_mem_cost) {
                    best_mem_cost = mem_cost;
                    nb_os_blocking_ = nb_os_blocking;
                    nb_oc_blocking_ = nb_oc_blocking;
                    nb_ic_blocking_ = nb_ic_blocking;
                    nthr_mb_ = nthr_mb;
                    nthr_oc_b_ = nthr_oc_b;
                    nthr_ic_b_ = nthr_ic_b;
                }
            }
        }
    }

    nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_;
}
789 | |
// Configure the backward-by-weights (diff_weights) inner product problem:
// pick ic/oc/os blocking, matrix dimensions (M/N/K map to ic/oc/os here),
// buffer usage flags, harness kind, and the thread decomposition.
// Returns status::unimplemented for bf32 shapes with poor tmul utilization.
status_t init_ip_conf_bwd_w(jit_brgemm_primitive_conf_t &jbgp) {
    const bool is_amx_xf16
            = is_superset(jbgp.isa, avx512_core_amx) && !jbgp.is_bf32;
    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);
    // A separate weights buffer is needed whenever accumulation type differs
    // from the weights type (e.g. f32 accumulation into bf16 weights).
    const bool has_weights_buffer = jbgp.wei_dt != jbgp.acc_dt;
    jbgp.is_amx = is_amx_xf16;

    const int amx_xf16_row = 64;
    // For f32 with ic divisible by 4*simd_w and small minibatch, a 4x-wide
    // ic block amortizes better.
    const bool big_ic_blk_ok
            = is_f32 && jbgp.ic % (4 * jbgp.simd_w) == 0 && jbgp.mb <= 128;
    jbgp.ic_block = big_ic_blk_ok && !is_amx_xf16
            ? 4 * jbgp.simd_w
            : (is_amx_xf16 && has_weights_buffer) ? amx_xf16_row : jbgp.simd_w;
    // External (layout-facing) ic block: 32 for AMX / bf16 weights, else 16.
    jbgp.ic_block_ext
            = is_amx_xf16 || (jbgp.wei_dt == dnnl::impl::data_type::bf16) ? 32
                                                                          : 16;

    jbgp.oc_block = has_weights_buffer ? get_oc_block(jbgp)
                                       : ip_fwd_get_adjusted_oc_block(jbgp);
    jbgp.oc_block_ext = ip_fwd_get_adjusted_oc_block(jbgp);

    jbgp.os_block = get_os_block(jbgp, false, false);
    jbgp.nb_os = div_up(jbgp.os, jbgp.os_block);

    jbgp.nb_ic = div_up(jbgp.ic, jbgp.ic_block);
    jbgp.nb_oc = div_up(jbgp.oc, jbgp.oc_block);
    jbgp.nb_oc_blocking = 1;
    // Pair up ic blocks when their count is even; refined by thread_balance
    // below.
    jbgp.nb_ic_blocking = jbgp.nb_ic % 2 ? 1 : 2;

    // Configure matrix sizes: for bwd_w the GEMM computes
    // diff_wei[ic, oc] += src^T[ic, os] * diff_dst[os, oc],
    // so M maps to ic, N to oc and the reduction K to os.
    jbgp.M = jbgp.ic_block;
    jbgp.M_tail = jbgp.ic % jbgp.ic_block;

    jbgp.N = jbgp.oc_block;
    jbgp.N_tail = jbgp.oc % jbgp.oc_block;

    constexpr int amx_xf16_granularity = 2;
    // sanity check, must hold for transpose routines to work fine
    assert(IMPLICATION(is_amx_xf16, jbgp.os_block % amx_xf16_granularity == 0));
    // AMX xf16 processes K in pairs; an odd os needs one element of padding.
    const bool do_rnd_os = is_amx_xf16 && jbgp.os % amx_xf16_granularity != 0;

    jbgp.K = jbgp.os_block;
    jbgp.K_tail = (jbgp.os % jbgp.os_block) + (do_rnd_os ? 1 : 0);

    jbgp.nb_os_blocking = 1;
    // Cap the os-chunk size: smaller chunks (8 or 4) for AMX when the source
    // fits (or not) in L2, otherwise up to 64 blocks per chunk.
    int os_blocking_max = (is_amx_xf16 && jbgp.nb_os >= 64)
            ? (types::data_type_size(jbgp.src_dt) * jbgp.mb * jbgp.ic
                      < platform::get_per_core_cache_size(2))
                    ? 8
                    : 4
            : nstl::min(64, jbgp.nb_os);

    // Pick the largest divisor of nb_os not exceeding the cap so chunks are
    // uniform (no os-chunk tail).
    for (int bl = os_blocking_max; bl >= 1; bl--)
        if (jbgp.nb_os % bl == 0) {
            jbgp.nb_os_blocking = bl;
            break;
        }

    jbgp.use_buffer_a = true;
    // Power-of-2 oc >= 512 and very large oc are prone to cache aliasing on
    // the B operand, so route them through a buffer as well.
    const bool is_oc_big_2_pow = jbgp.oc >= 512 && math::is_pow2(jbgp.oc);
    const bool is_huge_oc = jbgp.oc >= 4 * 1024;
    jbgp.use_buffer_b = jbgp.dst_dt != f32 || is_oc_big_2_pow || is_huge_oc;
    const bool os_dim_dominating = jbgp.os >= 5 * (jbgp.ic + jbgp.oc);
    const int big_nb_os_threshold = is_amx_xf16 ? 64 : 256;
    jbgp.ip_bwd_w_local_buffers_for_input_tensors
            = is_amx_xf16 && jbgp.nb_os >= big_nb_os_threshold;
    jbgp.harness = os_dim_dominating && jbgp.nb_os >= big_nb_os_threshold
            ? harness_mb_reduction
            : harness_2d_reduction;

    int nb_os_blocking, nb_oc_blocking, nb_ic_blocking, nthr, nthr_mb, nthr_oc,
            nthr_ic;
    // Caution: thread_balance requires `use_buffer_a` and `use_buffer_b`
    // fields of jbgp to be properly set
    thread_balance(jbgp, nb_os_blocking, nb_oc_blocking, nb_ic_blocking, nthr,
            nthr_mb, nthr_oc, nthr_ic);

    jbgp.nb_os_blocking = nb_os_blocking;
    jbgp.nb_oc_blocking = nb_oc_blocking;
    jbgp.nb_ic_blocking = nb_ic_blocking;
    jbgp.nthr = nthr;
    jbgp.nthr_mb = nthr_mb;
    jbgp.nthr_oc_b = nthr_oc;
    jbgp.nthr_ic_b = nthr_ic;

    jbgp.gemm_batch_size = jbgp.nb_os_blocking;
    // to avoid cache concurrent write access from different threads
    size_t sc_size = sizeof(brgemm_batch_element_t);
    jbgp.adjusted_batch_size
            = div_up(rnd_up(jbgp.gemm_batch_size * sc_size, 4096), sc_size);

    // An accumulation buffer is needed either to downconvert weights or to
    // reduce partial results across minibatch-parallel threads.
    jbgp.use_buffer = IMPLICATION(!has_weights_buffer, jbgp.nthr_mb > 1);

    jbgp.LDA = jbgp.K;
    jbgp.LDB = (jbgp.use_buffer_b) ? jbgp.N * jbgp.nb_oc_blocking
                                   : jbgp.oc_without_padding;
    jbgp.LDC = jbgp.LDD = jbgp.N;

    if (jbgp.is_bf32) {
        const float M = static_cast<float>(jbgp.M);
        const float N = nstl::min<float>(jbgp.N, jbgp.oc);
        const float K
                = nstl::min<float>(jbgp.K * jbgp.gemm_batch_size, jbgp.os);
        // Heuristic estimate of tile-multiply utilization relative to the
        // full 16x16x32 tmul shape; reject shapes that underutilize tiles.
        const float tmul_efficiency = (M / 16) * (N / 16) * (K / 32);
        // TODO: Adjust blocking such that bigger M, N, K are generated.
        if (one_of(true, M <= 8, K <= 8, N < 16, tmul_efficiency <= 2.25))
            return status::unimplemented;
    }

    return status::success;
}
901 | |
902 | size_t buf_dt_size(data_type_t dt, cpu_isa_t isa) { |
903 | const auto buf_dt = isa == avx512_core_fp16 && dt == data_type::f16 |
904 | ? data_type::f32 |
905 | : dt; |
906 | return types::data_type_size(buf_dt); |
907 | } |
908 | |
909 | status_t init_ip_conf(cpu_isa_t isa, jit_brgemm_primitive_conf_t &jbgp, |
910 | const inner_product_desc_t &ipd, memory_desc_t &src_md, |
911 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
912 | memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) { |
913 | const memory_desc_wrapper src_d(&src_md); |
914 | const memory_desc_wrapper weights_d(&weights_md); |
915 | const memory_desc_wrapper dst_d(&dst_md); |
916 | |
917 | using namespace prop_kind; |
918 | if (!mayiuse(avx512_core) && !mayiuse(avx2_vnni_2)) |
919 | return status::unimplemented; |
920 | |
921 | int ndims = src_d.ndims(); |
922 | if (weights_d.ndims() != ndims || dst_d.ndims() != 2) |
923 | return status::unimplemented; |
924 | |
925 | jbgp = zero<decltype(jbgp)>(); |
926 | jbgp.ndims = ndims; |
927 | jbgp.isa = isa; |
928 | jbgp.prop_kind = ipd.prop_kind; |
929 | jbgp.ngroups = 1; |
930 | jbgp.mb = src_d.dims()[0]; |
931 | jbgp.os = jbgp.mb; |
932 | jbgp.oc_without_padding = dst_d.dims()[1]; |
933 | jbgp.oc = jbgp.oc_without_padding; |
934 | jbgp.ic_without_padding = src_d.dims()[1]; |
935 | jbgp.ic = jbgp.ic_without_padding; |
936 | jbgp.id = (ndims == 5) ? src_d.dims()[2] : 1; |
937 | jbgp.ih = (ndims < 4) ? 1 : src_d.dims()[ndims - 2]; |
938 | jbgp.iw = (ndims < 3) ? 1 : src_d.dims()[ndims - 1]; |
939 | jbgp.od = jbgp.oh = jbgp.ow = 1; |
940 | jbgp.kd = (ndims == 5) ? weights_d.dims()[2] : 1; |
941 | jbgp.kh = (ndims < 4) ? 1 : weights_d.dims()[ndims - 2]; |
942 | jbgp.kw = (ndims < 3) ? 1 : weights_d.dims()[ndims - 1]; |
943 | jbgp.stride_d = jbgp.stride_h = jbgp.stride_w = 1; |
944 | |
945 | if (!everyone_is(1, jbgp.ow, jbgp.oh, jbgp.od)) |
946 | return status::unimplemented; |
947 | if (jbgp.kw != jbgp.iw || jbgp.kh != jbgp.ih || jbgp.kd != jbgp.id) |
948 | return status::unimplemented; |
949 | if (!everyone_is(1, jbgp.kw, jbgp.kh, jbgp.kd)) |
950 | return status::unimplemented; |
951 | |
952 | const int full_simd_w = 16; |
953 | jbgp.simd_w = full_simd_w; |
954 | |
955 | jbgp.with_bias |
956 | = pick_by_prop_kind(jbgp.prop_kind, ipd.bias_desc.format_kind, |
957 | format_kind::undef, ipd.diff_bias_desc.format_kind) |
958 | != format_kind::undef; |
959 | |
960 | jbgp.src_dt = src_d.data_type(); |
961 | jbgp.dst_dt = dst_d.data_type(); |
962 | jbgp.wei_dt = weights_d.data_type(); |
963 | jbgp.bia_dt = jbgp.with_bias |
964 | ? pick_by_prop_kind(jbgp.prop_kind, ipd.bias_desc.data_type, |
965 | data_type::undef, ipd.diff_bias_desc.data_type) |
966 | : data_type::undef; |
967 | jbgp.signed_input = one_of(isa, avx512_core_vnni, avx512_core_bf16) |
968 | && jbgp.src_dt == s8; |
969 | const bool is_int8 = one_of(jbgp.src_dt, u8, s8) && jbgp.wei_dt == s8; |
970 | const bool is_bf16 |
971 | = everyone_is(bf16, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt) |
972 | || pick_by_prop_kind(jbgp.prop_kind, |
973 | everyone_is(bf16, jbgp.src_dt, jbgp.wei_dt) |
974 | && jbgp.dst_dt == f32, |
975 | everyone_is(bf16, jbgp.wei_dt, jbgp.dst_dt) |
976 | && jbgp.src_dt == f32, |
977 | everyone_is(bf16, jbgp.src_dt, jbgp.dst_dt) |
978 | && jbgp.wei_dt == f32); |
979 | const bool is_f16 = everyone_is(f16, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt) |
980 | || pick_by_prop_kind(jbgp.prop_kind, |
981 | everyone_is(f16, jbgp.src_dt, jbgp.wei_dt) |
982 | && jbgp.dst_dt == f32, |
983 | everyone_is(f16, jbgp.wei_dt, jbgp.dst_dt) |
984 | && jbgp.src_dt == f32, |
985 | everyone_is(f16, jbgp.src_dt, jbgp.dst_dt) |
986 | && jbgp.wei_dt == f32); |
987 | const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt); |
988 | jbgp.is_bf32 = is_f32 && attr.fpmath_mode_ == fpmath_mode::bf16 |
989 | && isa == avx512_core_amx; |
990 | |
991 | if (!IMPLICATION(is_int8, |
992 | one_of(isa, avx512_core_vnni, avx512_core_bf16, |
993 | avx512_core_amx))) |
994 | return status::unimplemented; |
995 | if (!IMPLICATION(is_bf16, |
996 | one_of(isa, avx2_vnni_2, avx512_core_bf16, avx512_core_amx))) |
997 | return status::unimplemented; |
998 | if (!IMPLICATION(is_f32, jbgp.is_bf32 || (isa == avx512_core))) |
999 | return status::unimplemented; |
1000 | if (!IMPLICATION(is_f16, |
1001 | one_of(isa, avx2_vnni_2, avx512_core_fp16, |
1002 | avx512_core_amx_fp16))) |
1003 | return status::unimplemented; |
1004 | |
1005 | if (!one_of(true, is_int8, is_bf16, is_f16, is_f32)) |
1006 | return status::unimplemented; |
1007 | if (is_int8) { |
1008 | jbgp.acc_dt = s32; |
1009 | jbgp.with_scales = true; |
1010 | } else |
1011 | jbgp.acc_dt = f32; |
1012 | |
1013 | // Dispatch small shapes to VNNI for better performance |
1014 | const bool is_amx_int8 |
1015 | = jbgp.isa == avx512_core_amx && one_of(jbgp.wei_dt, s8, u8); |
1016 | const auto amx_row |
1017 | = static_cast<int32_t>(data_type_vnni_granularity(jbgp.src_dt)) |
1018 | * jbgp.simd_w; |
1019 | const auto max_size = is_amx_int8 ? 1024 : 512; |
1020 | const bool is_small_shapes |
1021 | = (jbgp.os <= 16 && jbgp.ic <= amx_row && jbgp.oc <= amx_row) |
1022 | || (jbgp.ic <= max_size && jbgp.oc <= max_size && jbgp.mb == 1 |
1023 | && jbgp.ic % amx_row != 0); |
1024 | if (one_of(jbgp.isa, avx512_core_amx, avx512_core_amx) && is_small_shapes) |
1025 | return status::unimplemented; |
1026 | |
1027 | auto set_or_check_tags = [&]() -> status_t { |
1028 | using namespace format_tag; |
1029 | format_tag_t desired_src_tag = pick(ndims - 2, nc, ncw, nchw, ncdhw); |
1030 | format_tag_t desired_dst_tag = nc; |
1031 | |
1032 | if (src_d.format_kind() == format_kind::any) { |
1033 | CHECK(memory_desc_init_by_tag(src_md, desired_src_tag)); |
1034 | jbgp.src_tag = desired_src_tag; |
1035 | } else { |
1036 | jbgp.src_tag |
1037 | = memory_desc_matches_one_of_tag(src_md, desired_src_tag); |
1038 | } |
1039 | |
1040 | if (dst_d.format_kind() == format_kind::any) { |
1041 | CHECK(memory_desc_init_by_tag(dst_md, desired_dst_tag)); |
1042 | jbgp.dst_tag = desired_dst_tag; |
1043 | } else { |
1044 | jbgp.dst_tag = memory_desc_matches_one_of_tag(dst_md, nc); |
1045 | } |
1046 | |
1047 | if (one_of(format_tag::undef, jbgp.src_tag, jbgp.dst_tag)) |
1048 | return status::unimplemented; |
1049 | |
1050 | if (jbgp.with_bias && bias_md.format_kind == format_kind::any) |
1051 | CHECK(memory_desc_init_by_tag(bias_md, x)); |
1052 | |
1053 | jbgp.is_wei_layout_any = weights_d.format_kind() == format_kind::any; |
1054 | |
1055 | memory_desc_t want_wei_md = weights_md; |
1056 | jbgp.wei_tag = get_brgemm_ip_weights_tag(isa, jbgp, weights_md); |
1057 | if (jbgp.wei_tag == format_tag::undef) return status::unimplemented; |
1058 | CHECK(memory_desc_init_by_tag(want_wei_md, jbgp.wei_tag)); |
1059 | |
1060 | if (jbgp.signed_input) { |
1061 | want_wei_md.extra.flags = 0 |
1062 | | memory_extra_flags::compensation_conv_s8s8 |
1063 | | memory_extra_flags::scale_adjust; |
1064 | want_wei_md.extra.compensation_mask = (1 << 0); |
1065 | want_wei_md.extra.scale_adjust |
1066 | = platform::s8s8_weights_scale_factor(); |
1067 | if (weights_md.format_kind != format_kind::any |
1068 | && want_wei_md != weights_md) |
1069 | return status::unimplemented; |
1070 | } |
1071 | weights_md = want_wei_md; |
1072 | return status::success; |
1073 | }; |
1074 | |
1075 | jbgp.brg_type = brgemm_addr; |
1076 | jbgp.nthr = nthreads; |
1077 | |
1078 | jbgp.use_uker = true; |
1079 | jbgp.use_interleave_stores = jbgp.use_uker; |
1080 | if (jbgp.use_uker) |
1081 | jbgp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; |
1082 | CHECK(set_or_check_tags()); |
1083 | CHECK(attr.set_default_formats(&dst_md)); |
1084 | |
1085 | switch (jbgp.prop_kind) { |
1086 | case forward_training: |
1087 | case forward_inference: |
1088 | CHECK(init_ip_conf_fwd(jbgp, attr, dst_d)); |
1089 | break; |
1090 | case backward_data: CHECK(init_ip_conf_bwd_d(jbgp)); break; |
1091 | case backward_weights: CHECK(init_ip_conf_bwd_w(jbgp)); break; |
1092 | default: assert(!"invalid prop_kind" ); return invalid_arguments; |
1093 | } |
1094 | |
1095 | return status::success; |
1096 | } |
1097 | |
1098 | void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
1099 | const jit_brgemm_primitive_conf_t &jbgp) { |
1100 | |
1101 | size_t sc_size = sizeof(brgemm_batch_element_t); |
1102 | size_t n_elems = (size_t)jbgp.nthr * jbgp.adjusted_batch_size; |
1103 | |
1104 | if (jbgp.brg_type == brgemm_addr) { |
1105 | scratchpad.book(key_brgemm_primitive_batch, n_elems, sc_size, 64); |
1106 | } |
1107 | if (jbgp.use_buffer) { |
1108 | size_t nelements = (size_t)jbgp.nthr * jbgp.LDC * jbgp.M; |
1109 | if (jbgp.prop_kind == dnnl_backward_weights |
1110 | && (jbgp.nthr_mb > 1 || jbgp.harness == harness_mb_reduction)) { |
1111 | const size_t n_reduction_buffers = jbgp.nthr_mb > 1 |
1112 | ? jbgp.nthr_mb - (jbgp.wei_dt == f32) |
1113 | : 1; |
1114 | const size_t num_ic_chunks |
1115 | = div_up(jbgp.nb_ic, jbgp.nb_ic_blocking); |
1116 | const size_t num_oc_chunks |
1117 | = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking); |
1118 | nelements = (size_t)n_reduction_buffers * num_ic_chunks |
1119 | * num_oc_chunks * jbgp.nb_ic_blocking * jbgp.nb_oc_blocking |
1120 | * jbgp.ic_block * jbgp.oc_block; |
1121 | } else if (jbgp.prop_kind == dnnl_backward_weights |
1122 | && jbgp.nthr_mb == 1) { |
1123 | nelements = (size_t)jbgp.nthr * jbgp.nb_ic_blocking * jbgp.ic_block |
1124 | * jbgp.nb_oc_blocking * jbgp.oc_block; |
1125 | } else if (jbgp.prop_kind == dnnl_backward_data && jbgp.nthr_oc_b > 1) { |
1126 | const int adj_buffers = (jbgp.src_dt == f32) ? 1 : 0; |
1127 | int n_reduction_buffers = jbgp.nthr_oc_b - adj_buffers; |
1128 | nelements = (size_t)n_reduction_buffers * jbgp.LDC * jbgp.os; |
1129 | } else if (one_of(jbgp.prop_kind, forward_training, forward_inference) |
1130 | && jbgp.nthr_ic_b > 1) { |
1131 | const bool |
1132 | = (jbgp.dst_dt == f32 && jbgp.with_sum); |
1133 | int n_reduction_buffers = jbgp.nthr_ic_b - !need_extra_buffer; |
1134 | nelements = (size_t)n_reduction_buffers * jbgp.oc * jbgp.os; |
1135 | } |
1136 | scratchpad.book(key_brgemm_primitive_buffer, nelements, |
1137 | types::data_type_size(jbgp.acc_dt)); |
1138 | } |
1139 | if (jbgp.use_buffer_a && jbgp.prop_kind == dnnl_backward_weights) { |
1140 | const dim_t num_ic_chunks_per_thread |
1141 | = jbgp.ip_bwd_w_local_buffers_for_input_tensors |
1142 | ? 1 |
1143 | : div_up(div_up(jbgp.nb_ic, jbgp.nb_ic_blocking), |
1144 | jbgp.nthr_ic_b); |
1145 | const dim_t num_os_chunks_per_thread |
1146 | = jbgp.ip_bwd_w_local_buffers_for_input_tensors |
1147 | ? 1 |
1148 | : div_up(div_up(jbgp.nb_os, jbgp.nb_os_blocking), jbgp.nthr_mb); |
1149 | const dim_t num_elems_per_thread = num_ic_chunks_per_thread |
1150 | * num_os_chunks_per_thread * jbgp.gemm_batch_size |
1151 | * jbgp.os_block * jbgp.ic_block * jbgp.nb_ic_blocking; |
1152 | scratchpad.book(key_brgemm_primitive_buffer_a, |
1153 | jbgp.nthr * num_elems_per_thread, |
1154 | buf_dt_size(jbgp.src_dt, jbgp.isa)); |
1155 | } else if (jbgp.use_buffer_a && jbgp.prop_kind == dnnl_backward_data) { |
1156 | scratchpad.book(key_brgemm_primitive_buffer_a, |
1157 | (size_t)jbgp.nthr * jbgp.os_block * jbgp.LDA, |
1158 | buf_dt_size(jbgp.dst_dt, jbgp.isa)); |
1159 | } else if (jbgp.use_buffer_a) { // FWD |
1160 | scratchpad.book(key_brgemm_primitive_buffer_a, |
1161 | (size_t)jbgp.nthr * jbgp.LDA * jbgp.os_block |
1162 | * jbgp.nb_os_blocking, |
1163 | buf_dt_size(jbgp.src_dt, jbgp.isa)); |
1164 | } |
1165 | |
1166 | if (jbgp.use_buffer_b && jbgp.prop_kind == dnnl_backward_weights) { |
1167 | int num_os_chunks_per_thread |
1168 | = jbgp.ip_bwd_w_local_buffers_for_input_tensors |
1169 | ? 1 |
1170 | : div_up(div_up(jbgp.nb_os, jbgp.nb_os_blocking), jbgp.nthr_mb); |
1171 | const dim_t num_elems_per_thread = num_os_chunks_per_thread |
1172 | * jbgp.gemm_batch_size * jbgp.os_block * jbgp.LDB; |
1173 | scratchpad.book(key_brgemm_primitive_buffer_b, |
1174 | (size_t)jbgp.nthr * num_elems_per_thread, |
1175 | buf_dt_size(jbgp.dst_dt, jbgp.isa)); |
1176 | } |
1177 | |
1178 | if (jbgp.use_buffer_b && jbgp.prop_kind == dnnl_backward_data) { |
1179 | auto size_B = (size_t)jbgp.LDB * rnd_up(jbgp.K, 2); |
1180 | |
1181 | if (!jbgp.ip_bwd_d_global_b_transpose) |
1182 | scratchpad.book(key_brgemm_primitive_buffer_b, |
1183 | (dim_t)jbgp.nthr * jbgp.gemm_batch_size * size_B, |
1184 | buf_dt_size(jbgp.wei_dt, jbgp.isa)); |
1185 | else |
1186 | scratchpad.book(key_brgemm_primitive_buffer_b, |
1187 | (dim_t)jbgp.nb_oc * jbgp.nb_ic * size_B, |
1188 | buf_dt_size(jbgp.wei_dt, jbgp.isa)); |
1189 | } |
1190 | |
1191 | if (jbgp.prop_kind == dnnl_backward_weights && jbgp.with_bias |
1192 | && (jbgp.bia_dt != f32 || jbgp.nthr_mb > 1)) { |
1193 | int nbuffers = jbgp.nthr_mb - (jbgp.bia_dt == f32); |
1194 | scratchpad.book(key_iprod_bias_bf16_convert_wsp, |
1195 | (size_t)nbuffers * jbgp.oc, types::data_type_size(jbgp.acc_dt)); |
1196 | } |
1197 | |
1198 | if (dnnl_thr_syncable() && jbgp.prop_kind == dnnl_backward_weights) |
1199 | scratchpad.book<simple_barrier::ctx_t>( |
1200 | key_conv_wei_bia_reduction_bctx, 1); |
1201 | |
1202 | if (jbgp.is_amx) |
1203 | scratchpad.book(key_conv_amx_tile_buffer, |
1204 | (size_t)jbgp.nthr * jbgp.amx_buf_size_per_thread, sizeof(char)); |
1205 | } |
1206 | |
1207 | } // namespace brgemm_inner_product_utils |
1208 | |
1209 | } // namespace x64 |
1210 | } // namespace cpu |
1211 | } // namespace impl |
1212 | } // namespace dnnl |
1213 | |