1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include <float.h>
17
18#include "common/c_types_map.hpp"
19#include "common/dnnl_thread.hpp"
20#include "common/memory_tracking.hpp"
21#include "common/nstl.hpp"
22#include "common/type_helpers.hpp"
23#include "common/utils.hpp"
24
25#include "cpu/platform.hpp"
26#include "cpu/x64/injectors/injector_utils.hpp"
27#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
28#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
29#include "cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp"
30#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"
31
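// GET_OFF(field) yields the byte offset of `field` within the
// jit_1x1_conv_call_s argument structure, so the generated kernel can address
// its run-time call parameters relative to the abi_param1 pointer.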
32#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
33
34namespace dnnl {
35namespace impl {
36namespace cpu {
37namespace x64 {
38
39using namespace dnnl::impl::prop_kind;
40using namespace dnnl::impl::utils;
41
42using namespace Xbyak;
43
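// The constructor sets up a post-ops injector when eltwise/binary post-ops
// are requested and a bf16 emulation helper when the target ISA lacks native
// bf16 instructions.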
44jit_avx512_core_bf16_1x1_conv_kernel::jit_avx512_core_bf16_1x1_conv_kernel(
45 const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr,
46 const memory_desc_t &dst_md)
47 : jit_generator(jit_name(), nullptr, ker_code_size, true, avx512_core_bf16)
48 , jcp(ajcp)
49 , attr_(attr) {
50 if (jcp.with_eltwise || jcp.with_binary) {
51 using namespace binary_injector;
52 static constexpr bool preserve_gpr = true;
53 static constexpr bool preserve_vmm = false;
54 static constexpr size_t helper_vmm_idx = 31;
55 const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
56 static constexpr bool use_exact_tail_scalar_bcast = true;
57
58 const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
59 r14, r15, r12, preserve_gpr, preserve_vmm,
60 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
61 memory_desc_wrapper(dst_md), tail_size, k_load_dim_tail_mask,
62 use_exact_tail_scalar_bcast};
63 const static_params_t static_params {
64 this->param1, rhs_arg_static_params};
65
66 postops_injector_ = utils::make_unique<
67 injector::jit_uni_postops_injector_t<avx512_core>>(
68 this, jcp.post_ops, static_params);
69 }
70
71 if (!isa_has_bf16(jcp.isa))
72 bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
73 bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
74 bf16_emu_reserv_4, bf16_emu_reserv_5, bf16_emu_reserv_6);
75}
76
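// bcast_loop walks the broadcast (spatial) dimension in steps of jcp.ur
// (jcp.bcast_block per full iteration), calling reduce_loop() for each step
// and advancing the bcast, output and store-buffer pointers; the remainder is
// handled through the tail labels below.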
77void jit_avx512_core_bf16_1x1_conv_kernel::bcast_loop(int load_loop_blk) {
78 mov(aux1_reg_bcast_data, reg_bcast_data);
79 mov(aux_reg_bcast_data, reg_bcast_data);
80
81 mov(aux_reg_output_data, reg_output_data);
82 mov(aux_reg_store_buf, reg_store_buf);
83
84 mov(reg_bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt));
85
86 Label bcast_loop;
87 Label bcast_loop_tail;
88 Label long_tail;
89
90 cmp(reg_bcast_loop_iter, jcp.ur);
91 jl(bcast_loop_tail, T_NEAR);
92
93 L(bcast_loop);
94 {
95 assert(jcp.bcast_block % jcp.ur == 0);
96 int num_substeps = jcp.bcast_block / jcp.ur;
97 assert(num_substeps > 0 && num_substeps < 10);
98 for (int i = 0; i < num_substeps; i++) {
99 if (i + 1 == num_substeps) L(long_tail);
100 reduce_loop(load_loop_blk, jcp.ur, i, false);
101 if (i < num_substeps - 1) {
102 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
103
104 add(aux_reg_output_data, jcp.bcast_loop_output_substep);
105 add(aux_reg_store_buf, jcp.bcast_loop_output_substep);
106 } else {
107 add(aux1_reg_bcast_data,
108 jcp.bcast_loop_bcast_step
109 - (num_substeps - 1)
110 * jcp.bcast_loop_bcast_substep);
111
112 add(aux_reg_output_data,
113 jcp.bcast_loop_output_step * jcp.typesize_out
114 - (num_substeps - 1)
115 * jcp.bcast_loop_output_substep);
116 add(aux_reg_store_buf,
117 jcp.bcast_loop_output_step * jcp.typesize_acc
118 - (num_substeps - 1)
119 * jcp.bcast_loop_output_substep);
120 }
121 sub(reg_bcast_loop_iter, jcp.ur);
122 }
123 cmp(reg_bcast_loop_iter, jcp.bcast_block);
124 jge(bcast_loop, T_NEAR);
125 }
126
127 L(bcast_loop_tail);
128 if (jcp.ur_tail) {
129 Label bcast_loop_tail_out;
130 if (jcp.ur_tail >= jcp.ur) {
131 cmp(reg_bcast_loop_iter, jcp.ur);
132 jge(long_tail, T_NEAR);
133 }
134 if (jcp.ur_tail % jcp.ur) {
135 cmp(reg_bcast_loop_iter, 0);
136 jle(bcast_loop_tail_out, T_NEAR);
137 reduce_loop(load_loop_blk, jcp.ur_tail % jcp.ur, 0, true);
138 L(bcast_loop_tail_out);
139 }
140 }
141}
142
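// Accumulator registers form a 2D grid: zmm index = i_ur * load_loop_blk
// + i_load. Index 31 is reserved for temporaries (bias, previous dst values,
// store conversions, permutation data), hence the assert below.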
143static int vreg_accum_idx(
144 const int load_loop_blk, const int i_load, const int i_ur) {
145 int idx = i_ur * load_loop_blk + i_load;
146 assert(idx < 31);
147 return idx;
148}
149
150Address jit_avx512_core_bf16_1x1_conv_kernel::output_ptr(
151 const int i_load, const int i_ur) {
152 if (one_of(jcp.prop_kind, forward_training, forward_inference,
153 backward_data)) {
154 const bool is_output_layout_nxc = is_out_layout_nxc();
155 int i_load_shift = is_output_layout_nxc
156 ? jcp.load_block
157 : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block;
158 int i_ur_shift = is_output_layout_nxc ? jcp.load_dim : jcp.load_block;
159 int offset = (i_load * i_load_shift + i_ur * i_ur_shift)
160 * jcp.typesize_out;
161 return EVEX_compress_addr(aux_reg_output_data, offset);
162 } else
163 return ptr[aux_reg_output_data
164 + (i_load ? reg_output_stride * i_load
165 : 0) // TODO: Xbyak should allow 0 scale
166 + jcp.typesize_out * jcp.load_block * i_ur];
167}
168
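// Helper that visits every (i_load, i_ur) accumulator pair; when mask_tail is
// set, the callback is told which i_load block is the (possibly partial) tail.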
169template <typename F>
170static void iterate(const int load_loop_blk, const int ur, const bool mask_tail,
171 const F &f) {
172 for (int i_load = 0; i_load < load_loop_blk; i_load++) {
173 const bool mask_flag = (mask_tail && i_load + 1 == load_loop_blk);
174 for (int i_ur = 0; i_ur < ur; i_ur++)
175 f(mask_flag, i_load, i_ur);
176 }
177}
178template <typename F>
179static void iterate(const int load_loop_blk, const int ur, const F &f) {
180 iterate(load_loop_blk, ur, false, f);
181}
182
183void jit_avx512_core_bf16_1x1_conv_kernel::apply_postops(
184 const int load_loop_blk, const int ur) {
185 if (jcp.with_eltwise || jcp.with_binary) {
186 injector_utils::vmm_index_set_t vmm_idxs;
187 if (jcp.with_binary) {
188 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
189 rhs_arg_params_tail;
190 const auto mask_tail = jcp.oc_without_padding % isa_simd_width_;
191 iterate(load_loop_blk, ur, mask_tail,
192 [&](const bool mask_flag, const int i_load,
193 const int i_ur) {
194 const int aux_output_l_off
195 = get_output_offset(i_load, i_ur);
196 const auto vmm_idx
197 = vreg_accum_idx(load_loop_blk, i_load, i_ur);
198 vmm_idxs.emplace(vmm_idx);
199
200 rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
201 vmm_idx, aux_reg_output_data);
202 rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
203 vmm_idx, aux_output_l_off);
204 if (mask_flag)
205 rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
206 });
207 rhs_arg_params = rhs_arg_params_tail;
208 rhs_arg_params.vmm_tail_idx_.clear();
209
210 const injector_utils::register_preserve_guard_t register_guard(
211 this, {reg_tmp});
212 const size_t reg_guard_stack_occupied
213 = register_guard.stack_space_occupied();
214
215 mov(abi_param1,
216 EVEX_compress_addr(rsp,
217 reg_abi_param1_backup + reg_guard_stack_occupied));
218
219 Label postops_done;
220 if (mask_tail) {
221 Label postops_no_tail;
222 mov(reg_tmp,
223 ptr[rsp + reg_load_loop_work_off
224 + reg_guard_stack_occupied]);
225 cmp(reg_tmp, jcp.oc_block * load_loop_blk);
226 jge(postops_no_tail, T_NEAR);
227 postops_injector_->compute_vector_range(
228 vmm_idxs, rhs_arg_params_tail);
229 jmp(postops_done, T_NEAR);
230 L(postops_no_tail);
231 }
232 postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
233 L(postops_done);
234
235 } else {
236 iterate(load_loop_blk, ur,
237 [&](const bool, const int i_load, const int i_ur) {
238 vmm_idxs.emplace(
239 vreg_accum_idx(load_loop_blk, i_load, i_ur));
240 });
241 postops_injector_->compute_vector_range(vmm_idxs);
242 }
243 }
244}
245
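// reduce_loop accumulates partial results over the reduce dimension with bf16
// dot-product FMAs (vdpbf16ps, or its emulation on ISAs without native bf16
// support); store() then adds bias/post-ops and writes either to the f32
// store buffer or to the final destination, depending on the reduction flags.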
246void jit_avx512_core_bf16_1x1_conv_kernel::reduce_loop(
247 int load_loop_blk, int ur, int substep, bool wraparound) {
248 const bool load_layout_nxc = is_load_layout_nxc();
249 const bool bcast_layout_nxc = is_bcast_layout_nxc();
250 const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block;
251 const int load_dim_tail = jcp.load_dim % jcp.load_block;
252
253 auto vreg_load = [=](int i_load) {
254 int idx = ur * load_loop_blk + i_load;
255 assert(idx < 31);
256 return Zmm(idx);
257 };
258 auto ymm_store = [=]() { return Xbyak::Ymm(31); };
259 auto zmm_store = [=]() { return Xbyak::Zmm(31); };
260
261 const auto vreg_accum = [=](int i_load, int i_ur) {
262 return Zmm(vreg_accum_idx(load_loop_blk, i_load, i_ur));
263 };
264
265 auto bias_ptr = [=](int i_load) {
266 return EVEX_compress_addr(
267 reg_bias_data, jcp.typesize_bia * jcp.oc_block * i_load);
268 };
269
270 auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
271 assert(i_ur < jcp.ur);
272 assert(i_reduce <= jcp.reduce_loop_unroll);
273 int offt;
274 if (one_of(jcp.prop_kind, forward_training, forward_inference,
275 backward_data)) {
276 assert(jcp.reduce_loop_unroll == jcp.reduce_block);
277 const int reduce_mul = bcast_layout_nxc ? jcp.reduce_dim
278 : jcp.reduce_loop_unroll;
279 offt = (i_reduce == jcp.reduce_loop_unroll)
280 ? (jcp.bcast_dim + i_ur) * reduce_mul
281 : i_ur * reduce_mul + i_reduce;
282 } else {
283 if (jcp.uses_permw_transposition) {
284 int rmul = bcast_layout_nxc ? jcp.ngroups * jcp.ic
285 : jcp.ic_block;
286 offt = i_reduce * rmul + i_ur;
287 } else {
288 offt = (i_reduce / 2) * 2 * jcp.ic_block + 2 * i_ur;
289 }
290 }
291 return EVEX_compress_addr(
292 aux_reg_bcast_data, jcp.typesize_in * offt, bcast);
293 };
294
295 auto load_ptr = [=](int i_reduce, int i_load) {
296 int u0 = i_reduce % jcp.reduce_loop_unroll;
297 int u1 = i_reduce / jcp.reduce_loop_unroll;
298 int lmul = jcp.load_block
299 * (load_layout_nxc ? 1
300 : utils::rnd_up(
301 jcp.reduce_dim, jcp.reduce_block));
302 int rmul = load_layout_nxc ? jcp.load_dim : jcp.load_block;
303 int offt = i_load * lmul + u0 * rmul;
304 return EVEX_compress_addr(aux_reg_load_data,
305 u1 * jcp.reduce_loop_load_step + jcp.typesize_in * offt);
306 };
307
308 auto store_buffer_ptr = [=](int i_load, int i_ur) {
309 const bool is_output_layout_nxc = is_out_layout_nxc();
310 int i_load_shift
311 = jcp.load_block * (is_output_layout_nxc ? 1 : jcp.bcast_dim);
312 int i_ur_shift = is_output_layout_nxc ? jcp.load_dim : jcp.load_block;
313 int offset = (i_load * i_load_shift + i_ur * i_ur_shift)
314 * jcp.typesize_acc;
315 return EVEX_compress_addr(aux_reg_store_buf, offset);
316 };
317
318 auto init = [=]() {
319 for (int i_load = 0; i_load < load_loop_blk; ++i_load)
320 for (int i_ur = 0; i_ur < ur; ++i_ur) {
321 auto r = vreg_accum(i_load, i_ur);
322 vpxord(r, r, r);
323 }
324 };
325
326 auto store = [=]() {
327 if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();
328
329 auto preamble = [=]() {
330 auto preamble_read = [=](bool from_buf) {
331 for (int i_ur = 0; i_ur < ur; ++i_ur) {
332 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
333 auto r = vreg_accum(i_load, i_ur);
334 if (jcp.prop_kind == backward_weights)
335 vaddps(r, r, output_ptr(i_load, i_ur));
336 else {
337 bool mask_flag = load_dim_tail
338 && i_load + 1 == load_loop_blk;
339 auto zmm_prev_dst = Xbyak::Zmm(31);
340 auto zmm_prev_dst_masked = may_be_mask_zmm(
341 zmm_prev_dst, mask_flag, true);
342 if (from_buf) {
343 vmovups(zmm_prev_dst,
344 store_buffer_ptr(i_load, i_ur));
345 } else if (jcp.dst_dt == data_type::bf16) {
346 vpmovzxwd(zmm_prev_dst_masked,
347 output_ptr(i_load, i_ur));
348 vpslld(zmm_prev_dst, zmm_prev_dst, 16);
349 } else {
350 vmovups(zmm_prev_dst_masked,
351 output_ptr(i_load, i_ur));
352 }
353 vaddps(r, zmm_prev_dst);
354 }
355 }
356 }
357 };
358 if (one_of(jcp.prop_kind, forward_training, forward_inference,
359 backward_data)) {
360 Label read_from_output;
361 Label read_done;
362
363 if (!jcp.with_sum) {
364 test(reg_reduce_pos_flag,
365 FLAG_REDUCE_FIRST); // If FLAG_REDUCE_FIRST
366 jnz(read_done, T_NEAR);
367 } else {
368 test(reg_reduce_pos_flag,
369 FLAG_REDUCE_FIRST); // If FLAG_REDUCE_FIRST
370 jnz(read_from_output, T_NEAR);
371 }
372 preamble_read(true);
373 jmp(read_done, T_NEAR);
374
375 L(read_from_output);
376 preamble_read(false);
377
378 L(read_done);
379 } else if (jcp.prop_kind == backward_weights) {
380 Label read_done;
381
382 test(reg_reduce_pos_flag,
383 FLAG_REDUCE_FIRST); // If FLAG_REDUCE_FIRST
384 jnz(read_done, T_NEAR);
385 preamble_read(false);
386
387 L(read_done);
388 }
389 };
390
391 const auto apply_bias = [=]() {
392 if (jcp.with_bias
393 && one_of(jcp.prop_kind, forward_training,
394 forward_inference)) {
395 auto zmm_bias = Xbyak::Zmm(31);
396 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
397 bool mask_flag
398 = load_dim_tail && i_load + 1 == load_loop_blk;
399 for (int i_ur = 0; i_ur < ur; ++i_ur) {
400 auto vreg_acc = vreg_accum(i_load, i_ur);
401 if (jcp.bia_dt == data_type::bf16) {
402 vpmovzxwd(
403 may_be_mask_zmm(zmm_bias, mask_flag, true),
404 bias_ptr(i_load));
405 vpslld(zmm_bias, zmm_bias, 16);
406 vaddps(vreg_acc, zmm_bias);
407 } else {
408 vaddps(may_be_mask_zmm(vreg_acc, mask_flag, true),
409 bias_ptr(i_load));
410 }
411 }
412 }
413 }
414 };
415
416 const auto apply_bias_and_postops = [=]() {
417 Label store_no_post_ops;
418
419 test(reg_reduce_pos_flag,
420 FLAG_REDUCE_LAST); // If Not FLAG_REDUCE_LAST
421 jz(store_no_post_ops, T_NEAR);
422
423 apply_bias();
424 apply_postops(load_loop_blk, ur);
425 L(store_no_post_ops);
426 };
427
428 auto store_output = [=](bool to_buf) {
429 if (to_buf) {
430 for (int i_ur = 0; i_ur < ur; ++i_ur) {
431 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
432 vmovups(store_buffer_ptr(i_load, i_ur),
433 vreg_accum(i_load, i_ur));
434 }
435 }
436 } else if (jcp.prop_kind == backward_weights
437 || jcp.dst_dt == data_type::f32) {
438 for (int i_ur = 0; i_ur < ur; ++i_ur) {
439 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
440 auto vreg_acc = vreg_accum(i_load, i_ur);
                        // For nxc-layout bwd_w, the weights are still padded
                        // and output_ptr here may point to uninitialized
                        // scratchpad. To ensure the final output (after
                        // reduction) is zero-padded, zero-pad the output here
                        // by omitting the mask.
445 bool mask_flag = (jcp.prop_kind != backward_weights)
446 && load_dim_tail && i_load + 1 == load_loop_blk;
447 vmovups(output_ptr(i_load, i_ur),
448 may_be_mask_zmm(vreg_acc, mask_flag, false));
449 }
450 }
451 } else if (jcp.dst_dt == data_type::bf16) {
452 if (isa_has_bf16(jcp.isa) && is_out_layout_nxc()) {
                    // Optimization: use a single store instruction for a
                    // pair of adjacent vectors along the LOAD dimension
455 for (int i_ur = 0; i_ur < ur; i_ur++) {
456 int i_load = 0;
457 for (; i_load < rnd_dn(load_loop_blk, 2); i_load += 2) {
458 auto zmm = vreg_accum(i_load, i_ur);
459 auto zmm_next = vreg_accum(i_load + 1, i_ur);
460 vcvtne2ps2bf16(zmm, zmm_next, zmm);
461 bool mask_flag = load_dim_tail
462 && i_load + 2 == load_loop_blk;
463 vmovdqu16(output_ptr(i_load, i_ur),
464 may_be_mask_zmm(
465 zmm, mask_flag, false, true));
466 }
467 if (load_loop_blk % 2 != 0) {
468 auto zmm = vreg_accum(i_load, i_ur);
469 auto ymm = Ymm(zmm.getIdx());
470 vcvtneps2bf16(ymm, zmm);
471 vmovdqu16(output_ptr(i_load, i_ur),
472 may_be_mask_ymm(ymm, load_dim_tail));
473 }
474 }
475 } else if (isa_has_bf16(jcp.isa)) {
                    // Optimization: use a single store instruction for a
                    // pair of adjacent vectors along the BCAST dimension
478 assert(!is_out_layout_nxc());
479 for (int i_load = 0; i_load < load_loop_blk; i_load++) {
480 int n_2bf2ps = (ur / 2) * 2, i_ur = 0;
481 for (i_ur = 0; i_ur < n_2bf2ps; i_ur += 2) {
482 auto zmm = zmm_store();
483 vcvtne2ps2bf16(zmm, vreg_accum(i_load, i_ur + 1),
484 vreg_accum(i_load, i_ur));
485 vmovups(output_ptr(i_load, i_ur), zmm);
486 }
487 if (i_ur < ur) {
488 auto ymm = ymm_store();
489 vcvtneps2bf16(ymm, vreg_accum(i_load, i_ur));
490 vmovups(output_ptr(i_load, i_ur), ymm);
491 }
492 }
493 } else {
494 for (int i_load = 0; i_load < load_loop_blk; i_load++) {
495 for (int i_ur = 0; i_ur < ur; ++i_ur) {
496 auto ymm = ymm_store();
497 bf16_emu_->vcvtneps2bf16(
498 ymm, vreg_accum(i_load, i_ur));
499 bool mask_flag = load_dim_tail
500 && i_load + 1 == load_loop_blk;
501 vmovdqu16(output_ptr(i_load, i_ur),
502 may_be_mask_ymm(ymm, mask_flag));
503 }
504 }
505 }
506 } else {
507 assert(!"unsupported destination type");
508 }
509 };
510
511 preamble();
512
513 apply_bias_and_postops();
514
515 if (jcp.prop_kind == backward_weights) {
516 store_output(false);
517 } else {
518 Label store_to_output;
519 Label store_done;
520
521 test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); // If FLAG_REDUCE_LAST
522 jnz(store_to_output, T_NEAR);
523
524 store_output(true);
525 jmp(store_done, T_NEAR);
526
527 L(store_to_output);
528 store_output(false);
529
530 L(store_done);
531 }
532 };
533
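    // Backward-by-weights FMA block. With uses_permw_transposition the bf16
    // rows are interleaved via vpermw, and a small software pipeline (up to 4
    // iterations deep) of load registers is used to hide the load and permute
    // latency behind the vdpbf16ps accumulation.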
534 auto fma_block_bwd_w = [=](bool is_tail) {
535 int n_reduce_tail = jcp.reduce_dim % jcp.reduce_loop_unroll;
536 int n_reduce
537 = is_tail && n_reduce_tail > 0 && !jcp.uses_permw_transposition
538 ? n_reduce_tail
539 : jcp.reduce_loop_unroll;
540 int bcast_count = 0;
541 int pipeline_length_max = 1;
542 if (isa_has_bf16(jcp.isa)) {
543 const int max_regs = 32;
544 const int regs_for_accum = ur * load_loop_blk;
545 const int regs_for_pipeline_total
546 = max_regs - regs_for_accum - jcp.uses_permw_transposition;
547 const int regs_for_pipeline_iter
548 = load_loop_blk + jcp.uses_permw_transposition;
549 assert(regs_for_pipeline_total >= regs_for_pipeline_iter);
550 pipeline_length_max = nstl::min(
551 regs_for_pipeline_total / regs_for_pipeline_iter,
552 n_reduce / 2);
553 }
554
555 const int pipeline = saturate(1, 4, pipeline_length_max);
556 auto zmm_prm = [=]() { return zmm_store(); };
557 auto get_load_start_idx = [=](int bcast_count) {
558 return pipeline * jcp.uses_permw_transposition
559 + (bcast_count % pipeline) * load_loop_blk;
560 };
561 auto pipeline_bcast_ptr = [=](int i_reduce, int i_ur, bool bcast,
562 int pipeline_idx) {
563 if (jcp.uses_permw_transposition) {
564 int offset = 64 * pipeline_idx + jcp.typesize_in * 2 * i_ur;
565 auto p = rsp + broadcast_space + offset;
566 return bcast ? zword_b[p] : ptr[p];
567 } else {
568 return bcast_ptr(i_reduce, i_ur, bcast);
569 }
570 };
571
572 auto get_load_mask = [=](int i_reduce) {
573 bool is_reduce_tail = jcp.reduce_loop_unroll % 2
574 && i_reduce + 2 >= jcp.reduce_loop_unroll;
575 return is_load_layout_nxc() || is_reduce_tail ? half_mask
576 : full_mask;
577 };
578 auto get_bcast_mask = [=](int i_reduce) {
579 bool is_reduce_tail = jcp.reduce_loop_unroll % 2
580 && i_reduce + 2 >= jcp.reduce_loop_unroll;
581 return is_bcast_layout_nxc() || is_reduce_tail ? half_mask
582 : full_mask;
583 };
584
585 if (jcp.uses_permw_transposition) {
586 mov(EVEX_compress_addr(rsp, perm_reg_offset), reg_reduce_pos_flag);
587 mov(reg_trans_tmp, dst_prm_table);
588 vmovups(zmm_prm(), ptr[reg_trans_tmp]);
589
590 for (; bcast_count < pipeline; bcast_count++) {
591 int i_reduce = 2 * bcast_count;
592 bool is_reduce_tail = jcp.reduce_loop_unroll % 2
593 && i_reduce + 2 >= jcp.reduce_loop_unroll;
594 int load_idx = get_load_start_idx(bcast_count);
595 Opmask bcast_mask = get_bcast_mask(i_reduce);
596 auto bcast_values = vreg_load(bcast_count);
597
598 vmovdqu16(bcast_values | bcast_mask | T_z,
599 bcast_ptr(i_reduce, 0, false));
600 if (is_bcast_layout_nxc() && !is_reduce_tail) {
601 // Reuse i_ur argument to shift back pointer
602 vmovdqu16(bcast_values | half_mask_hi,
603 bcast_ptr(i_reduce + 1, -jcp.ic_block, false));
604 }
605 vpermw(bcast_values, zmm_prm(), bcast_values);
606 vmovups(pipeline_bcast_ptr(i_reduce, 0, false, bcast_count),
607 bcast_values);
608
609 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
610 vmovdqu16(vreg_load(load_idx + i_load) | bcast_mask | T_z,
611 load_ptr(i_reduce, i_load));
612 if (is_load_layout_nxc() && !is_reduce_tail) {
613 vmovdqu16(vreg_load(load_idx + i_load) | half_mask_hi,
614 load_ptr(i_reduce + 1, i_load - 1));
615 }
616 vpermw(vreg_load(load_idx + i_load), zmm_prm(),
617 vreg_load(load_idx + i_load));
618 }
619 }
620 } else {
621 for (; bcast_count < pipeline; bcast_count++) {
622 int i_reduce = 2 * bcast_count;
623 int load_idx = get_load_start_idx(bcast_count);
624
625 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
626 auto vreg = vreg_load(load_idx + i_load);
627 bool mask_flag
628 = load_dim_tail && i_load + 1 == load_loop_blk;
629 vmovups(may_be_mask_zmm(vreg, mask_flag, true),
630 load_ptr(i_reduce, i_load));
631 }
632 }
633 }
634
635 int use_bcast_count = 0;
636 for (int i_reduce = 0; i_reduce < n_reduce; i_reduce += 2) {
637 int bcast_pl_idx = use_bcast_count % pipeline;
638 for (int i_ur = 0; i_ur < ur; ++i_ur) {
639 // TODO: try to enable jcp.expl_bcast version
640 if (jcp.expl_bcast && load_loop_blk > 1) {
641 vpbroadcastd(vreg_bcast,
642 pipeline_bcast_ptr(
643 i_reduce, i_ur, false, bcast_pl_idx));
644 }
645 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
646 int load_idx = get_load_start_idx(use_bcast_count) + i_load;
647 if (!isa_has_bf16(jcp.isa)) {
648 if (jcp.uses_permw_transposition) {
649 bool is_reduce_tail = jcp.reduce_loop_unroll % 2
650 && i_reduce + 2 >= jcp.reduce_loop_unroll;
651 Opmask load_mask = get_load_mask(i_reduce);
652 vmovdqu16(vreg_load(i_load) | load_mask | T_z,
653 load_ptr(i_reduce, i_load));
654 if (is_load_layout_nxc() && !is_reduce_tail) {
655 vmovdqu16(vreg_load(i_load) | half_mask_hi,
656 load_ptr(i_reduce + 1, i_load - 1));
657 }
658 vpermw(vreg_load(i_load), zmm_prm(),
659 vreg_load(i_load));
660 } else
661 vmovups(vreg_load(i_load),
662 load_ptr(i_reduce, i_load));
663 }
664 if (jcp.expl_bcast && load_loop_blk > 1) {
665 if (!isa_has_bf16(jcp.isa)) {
666 auto acc = vreg_accum(i_load, i_ur);
667 auto wei = vreg_load(i_load);
668 bf16_emu_->vdpbf16ps(acc, wei, vreg_bcast);
669 } else
670 vdpbf16ps(vreg_accum(i_load, i_ur),
671 vreg_load(load_idx), vreg_bcast);
672 } else {
673 if (!isa_has_bf16(jcp.isa)) {
674 vpbroadcastd(zmm_tmp2,
675 pipeline_bcast_ptr(
676 i_reduce, i_ur, false, 0));
677 auto acc = vreg_accum(i_load, i_ur);
678 auto wei = vreg_load(i_load);
679 bf16_emu_->vdpbf16ps(acc, wei, zmm_tmp2);
680 } else
681 vdpbf16ps(vreg_accum(i_load, i_ur),
682 vreg_load(load_idx),
683 pipeline_bcast_ptr(i_reduce, i_ur, true,
684 bcast_pl_idx));
685 }
686 }
687 }
688 use_bcast_count++;
689 if (bcast_count < div_up(n_reduce, 2)) {
690 int load_idx = get_load_start_idx(bcast_count);
691 int i_reduce = bcast_count * 2;
692
693 if (jcp.uses_permw_transposition) {
694 bool is_reduce_tail = jcp.reduce_loop_unroll % 2
695 && i_reduce + 2 >= jcp.reduce_loop_unroll;
696 Opmask bcast_mask = get_bcast_mask(i_reduce);
697 int bcast_pl_idx = bcast_count % pipeline;
698 auto bcast_values = vreg_load(bcast_pl_idx);
699
700 vmovdqu16(bcast_values | bcast_mask | T_z,
701 bcast_ptr(i_reduce, 0, false));
702 if (is_bcast_layout_nxc() && !is_reduce_tail) {
703 vmovdqu16(bcast_values | half_mask_hi,
704 bcast_ptr(i_reduce + 1, -jcp.ic_block, false));
705 }
706 vpermw(bcast_values, zmm_prm(), bcast_values);
707 vmovups(pipeline_bcast_ptr(
708 i_reduce, 0, false, bcast_pl_idx),
709 bcast_values);
710
711 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
712 vmovdqu16(
713 vreg_load(load_idx + i_load) | bcast_mask | T_z,
714 load_ptr(i_reduce, i_load));
715 if (is_load_layout_nxc() && !is_reduce_tail) {
716 vmovdqu16(
717 vreg_load(load_idx + i_load) | half_mask_hi,
718 load_ptr(i_reduce + 1, i_load - 1));
719 }
720 vpermw(vreg_load(load_idx + i_load), zmm_prm(),
721 vreg_load(load_idx + i_load));
722 }
723 } else {
724 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
725 vmovups(vreg_load(load_idx + i_load),
726 load_ptr(i_reduce, i_load));
727 }
728 }
729 bcast_count++;
730 }
731 }
732 if (jcp.uses_permw_transposition)
733 mov(reg_reduce_pos_flag, EVEX_compress_addr(rsp, perm_reg_offset));
734 };
735
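    // Forward / backward-by-data FMA block: broadcasts a pair of bf16 reduce
    // elements and multiply-accumulates them against the load registers. When
    // the reduce dimension has an odd tail, the last element is loaded with
    // vpbroadcastw and the duplicated high word is cleared so that only one
    // valid element takes part in the dot product.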
736 auto fma_block_fwd_bwd_d = [=](bool is_tail) {
737 int n_reduce_tail = jcp.reduce_dim % jcp.reduce_loop_unroll;
738 int n_reduce = is_tail && n_reduce_tail > 0 ? n_reduce_tail
739 : jcp.reduce_loop_unroll;
740 const int reduce_step = 2;
741 for (int i_reduce = 0; i_reduce < n_reduce; i_reduce += 2) {
742 if (isa_has_bf16(jcp.isa)) {
743 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
744 vmovdqu16(vreg_load(i_load), load_ptr(i_reduce, i_load));
745 }
746 }
747 for (int i_ur = 0; i_ur < ur; ++i_ur) {
748 const bool need_safe_reduce_dim_load
749 = (i_reduce == rnd_dn(reduce_dim_tail, reduce_step))
750 && reduce_dim_tail % reduce_step;
751 if (jcp.expl_bcast && load_loop_blk > 1) {
752 Label reduce_load_done;
753 if (need_safe_reduce_dim_load) {
754 Label skip_tail_load;
755 cmp(reduce_loop_iter, i_reduce + reduce_step);
756 jge(skip_tail_load, T_NEAR);
757 vpbroadcastw(
758 vreg_bcast, bcast_ptr(i_reduce, i_ur, false));
                        // clear duplicate high word
760 vpsrld(vreg_bcast, vreg_bcast, 16);
761 jmp(reduce_load_done, T_NEAR);
762 L(skip_tail_load);
763 }
764 vpbroadcastd(vreg_bcast, bcast_ptr(i_reduce, i_ur, false));
765 L(reduce_load_done);
766 }
767 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
768 if (!isa_has_bf16(jcp.isa)) {
769 vmovdqu16(
770 vreg_load(i_load), load_ptr(i_reduce, i_load));
771 }
772 if (jcp.expl_bcast && load_loop_blk > 1) {
773 if (!isa_has_bf16(jcp.isa)) {
774 auto acc = vreg_accum(i_load, i_ur);
775 auto wei = vreg_load(i_load);
776 bf16_emu_->vdpbf16ps(acc, wei, vreg_bcast);
777 } else
778 vdpbf16ps(vreg_accum(i_load, i_ur),
779 vreg_load(i_load), vreg_bcast);
780 } else {
781 if (!isa_has_bf16(jcp.isa)) {
782 Label reduce_load_done;
783 if (need_safe_reduce_dim_load) {
784 Label skip_tail_load;
785 cmp(reduce_loop_iter, i_reduce + reduce_step);
786 jge(skip_tail_load, T_NEAR);
787 vpbroadcastw(zmm_tmp2,
788 bcast_ptr(i_reduce, i_ur, false));
                                // clear duplicate high word
790 vpsrld(zmm_tmp2, zmm_tmp2, 16);
791 jmp(reduce_load_done, T_NEAR);
792 L(skip_tail_load);
793 }
794 vpbroadcastd(
795 zmm_tmp2, bcast_ptr(i_reduce, i_ur, false));
796 L(reduce_load_done);
797 auto acc = vreg_accum(i_load, i_ur);
798 auto wei = vreg_load(i_load);
799 bf16_emu_->vdpbf16ps(acc, wei, zmm_tmp2);
800 } else {
801 auto vreg_acc = vreg_accum(i_load, i_ur);
802 bool mask_flag = i_load + 1 == load_loop_blk
803 && load_dim_tail;
804 vreg_acc = may_be_mask_zmm(
805 vreg_acc, mask_flag, true);
806 if (need_safe_reduce_dim_load) {
807 Label reduce_load_done;
808 Label skip_tail_load;
809 cmp(reduce_loop_iter, i_reduce + reduce_step);
810 jge(skip_tail_load, T_NEAR);
811 vpbroadcastw(zmm_tmp2,
812 bcast_ptr(i_reduce, i_ur, false));
813 // clear duplicate high word
814 vpsrld(zmm_tmp2, zmm_tmp2, 16);
815 jmp(reduce_load_done, T_NEAR);
816 L(skip_tail_load);
817 vpbroadcastd(zmm_tmp2,
818 bcast_ptr(i_reduce, i_ur, false));
819 L(reduce_load_done);
820 vdpbf16ps(
821 vreg_acc, vreg_load(i_load), zmm_tmp2);
822 } else {
823 vdpbf16ps(vreg_acc, vreg_load(i_load),
824 bcast_ptr(i_reduce, i_ur, true));
825 }
826 }
827 }
828 }
829 }
830 }
831 };
832
833 auto fma_block = [=](bool is_tail) {
834 return (jcp.prop_kind == backward_weights)
835 ? fma_block_bwd_w(is_tail)
836 : fma_block_fwd_bwd_d(is_tail);
837 };
838
839 Label reduce_loop;
840 Label reduce_loop_tail;
841
842 mov(aux_reg_load_data, reg_load_data);
843
844 mov(aux_reg_bcast_data, aux1_reg_bcast_data);
845 init();
846
847 mov(reduce_loop_iter, reg_reduce_loop_work);
848 Label reduce_loop_exit;
849 cmp(reduce_loop_iter, jcp.reduce_loop_unroll);
850 jl(reduce_loop_tail, T_NEAR);
851
852 L(reduce_loop);
853 {
854 fma_block(false);
855 add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
856 add(aux_reg_load_data, jcp.reduce_loop_load_step);
857 sub(reduce_loop_iter, jcp.reduce_loop_unroll);
858 cmp(reduce_loop_iter, jcp.reduce_loop_unroll);
859 jge(reduce_loop, T_NEAR);
860 }
861
862 L(reduce_loop_tail);
863 cmp(reduce_loop_iter, 0);
864 jle(reduce_loop_exit, T_NEAR);
865
866 fma_block(true);
867 L(reduce_loop_exit);
868 store();
869}
870
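// compute_diff_bias (bwd_w only) reduces diff_dst over the reduce dimension
// by dot-multiplying each bf16 pair with a broadcast vector of bf16 ones
// (0x3f80), which sums the elements into f32 accumulators.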
871void jit_avx512_core_bf16_1x1_conv_kernel::compute_diff_bias(
872 int load_loop_blk) {
873 if (IMPLICATION(jcp.with_bias, jcp.prop_kind != backward_weights)) return;
874 Label skip_diff_bias;
875 test(reg_reduce_pos_flag, FLAG_COMPUTE_BIAS);
876 jz(skip_diff_bias, T_NEAR);
877
878 auto vunit = Zmm(31);
879 auto vreg_prm = Zmm(30);
880
881 auto get_load_offset = [=](int i_reduce, int i_load) {
882 dim_t lmul
883 = jcp.load_block * (is_load_layout_nxc() ? 1 : jcp.reduce_dim);
884 dim_t rmul = (is_load_layout_nxc() ? jcp.load_dim : jcp.load_block);
885 return (i_load * lmul + i_reduce * rmul) * jcp.typesize_in;
886 };
887 auto load_ptr = [=](int i_reduce, int i_load, int offset = 0) {
888 return EVEX_compress_addr(
889 aux_reg_load_data, get_load_offset(i_reduce, i_load) + offset);
890 };
891 auto bias_ptr = [=](int i_load) {
892 return ptr[reg_bias_data + i_load * jcp.load_block * jcp.typesize_acc];
893 };
894
895 auto vreg_acc = [=](int i_load) { return Zmm(i_load); };
896 auto vreg_load = [=](int i_load) { return Zmm(load_loop_blk + i_load); };
897
898 auto compute_diff_bias_block = [=](bool is_tail) {
899 for (int i_load = 0; i_load < load_loop_blk; i_load++) {
900 auto vacc = vreg_acc(i_load);
901 auto vload = vreg_load(i_load);
902 auto vload_masked = is_load_layout_nxc() || is_tail
903 ? vload | half_mask | T_z
904 : vload;
905 if (jcp.uses_permw_transposition) {
906 vmovdqu16(vload_masked, load_ptr(0, i_load));
907 if (is_load_layout_nxc() && !is_tail) {
908 const int shift_16_elems = 16 * jcp.typesize_in;
909 vmovdqu16(vload | half_mask_hi,
910 load_ptr(0, i_load, -shift_16_elems));
911 }
912 vpermw(vload, vreg_prm, vload);
913 } else {
914 vmovups(vload_masked, load_ptr(0, i_load));
915 }
916 if (!isa_has_bf16(jcp.isa))
917 bf16_emu_->vdpbf16ps(vacc, vload, vunit);
918 else
919 vdpbf16ps(vacc, vload, vunit);
920 }
921 };
922
923 auto reg_unit_val = reg_bcast_loop_iter.cvt16();
924 mov(reg_unit_val, 0x3f80); // bf16 value of 1.
925 vpbroadcastw(vunit, reg_unit_val);
926
927 if (jcp.uses_permw_transposition) {
928 mov(reg_bcast_loop_iter, dst_prm_table);
929 vmovups(vreg_prm, ptr[reg_bcast_loop_iter]);
930 }
931 for (int i_load = 0; i_load < load_loop_blk; i_load++) {
932 auto vacc = vreg_acc(i_load);
933 vpxord(vacc, vacc, vacc);
934 }
935
936 mov(aux_reg_load_data, reg_load_data);
937 mov(reduce_loop_iter, reg_reduce_loop_work);
938 const int reduce_step = 2;
939 Label reduce_loop, reduce_loop_tail, reduce_loop_exit;
940 cmp(reduce_loop_iter, reduce_step);
941 jl(reduce_loop_tail, T_NEAR);
942
943 L(reduce_loop);
944 {
945 compute_diff_bias_block(false);
946 add(aux_reg_load_data, get_load_offset(reduce_step, 0));
947 sub(reduce_loop_iter, reduce_step);
948 cmp(reduce_loop_iter, reduce_step);
949 jge(reduce_loop, T_NEAR);
950 }
951
952 L(reduce_loop_tail);
953 cmp(reduce_loop_iter, 0);
954 jle(reduce_loop_exit, T_NEAR);
955
956 compute_diff_bias_block(true);
957 L(reduce_loop_exit);
958
959 Label skip_reading;
960 test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); // If FLAG_REDUCE_FIRST
961 jnz(skip_reading, T_NEAR);
962
963 const int load_dim_tail = jcp.load_dim % jcp.load_block;
964 for (int i_load = 0; i_load < load_loop_blk; i_load++) {
965 bool mask_flag = load_dim_tail && i_load + 1 == load_loop_blk;
966 auto vacc = vreg_acc(i_load);
967 vaddps(may_be_mask_zmm(vacc, mask_flag, true), vacc, bias_ptr(i_load));
968 }
969
970 L(skip_reading);
971 for (int i_load = 0; i_load < load_loop_blk; i_load++) {
972 bool mask_flag = load_dim_tail && i_load + 1 == load_loop_blk;
973 auto vacc = vreg_acc(i_load);
974 vmovups(bias_ptr(i_load), may_be_mask_zmm(vacc, mask_flag, false));
975 }
976
977 L(skip_diff_bias);
978}
979
980void jit_avx512_core_bf16_1x1_conv_kernel::generate() {
981 preamble();
982
983 sub(rsp, stack_space_needed);
984 if (jcp.with_binary) {
985 const auto zeroed_reg = r15;
986 xor_(zeroed_reg, zeroed_reg);
987 mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off), zeroed_reg);
988 mov(EVEX_compress_addr(rsp, reg_abi_param1_backup), abi_param1);
989 }
990
991 mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
992 mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
993 mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
994
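    // Opmask constants: half_mask covers the low 16 of 32 bf16 word lanes,
    // full_mask covers all 32, and half_mask_hi covers the upper 16; they are
    // used by the masked vmovdqu16 loads in the backward-by-weights paths.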
995 mov(reg_trans_tmp.cvt32(), 0x0000ffff);
996 kmovd(half_mask, reg_trans_tmp.cvt32());
997 mov(reg_trans_tmp.cvt32(), 0xffffffff);
998 kmovd(full_mask, reg_trans_tmp.cvt32());
999 mov(reg_trans_tmp.cvt32(), 0xffff0000);
1000 kmovd(half_mask_hi, reg_trans_tmp.cvt32());
1001
1002 const int load_dim_tail
1003 = (one_of(jcp.prop_kind, forward_training, forward_inference)
1004 ? jcp.oc_without_padding
1005 : jcp.load_dim)
1006 % jcp.load_block;
1007 if (load_dim_tail) {
1008 mov(reg_trans_tmp.cvt32(), (1 << load_dim_tail) - 1);
1009 kmovw(k_load_dim_tail_mask, reg_trans_tmp.cvt32());
1010
1011 if (is_out_layout_nxc())
1012 mov(reg_trans_tmp.cvt32(),
1013 (1 << (load_dim_tail + jcp.load_block)) - 1);
1014 else {
1015 const auto half_mask = (1 << load_dim_tail) - 1;
1016 mov(reg_trans_tmp.cvt32(), ((half_mask << 16) + half_mask));
1017 }
1018
1019 kmovd(k_load_dim_tail_mask_extended, reg_trans_tmp.cvt32());
1020 }
1021
1022 if (jcp.with_bias) mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
1023
1024 mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
1025 mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
1026 mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work);
1027 mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
1028
1029 if (one_of(jcp.prop_kind, backward_data, forward_training,
1030 forward_inference)) {
1031 mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
1032 mov(reg_store_buf, ptr[param1 + GET_OFF(store_buffer)]);
1033 }
1034 if (jcp.prop_kind == backward_weights) {
1035 mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
1036 mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
1037 }
1038
1039 auto load_loop_body = [=](int load_loop_blk) {
1040 Label no_update_mask, update_mask_done;
1041 if (load_dim_tail) {
1042 cmp(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
1043 jge(no_update_mask, T_NEAR);
1044 kmovd(k_load_dim_mask, k_load_dim_tail_mask);
1045 kmovd(k_load_dim_mask_extended, k_load_dim_tail_mask_extended);
1046 jmp(update_mask_done, T_NEAR);
1047 L(no_update_mask);
1048 }
1049 kxnord(k_load_dim_mask, k_load_dim_mask, k_load_dim_mask);
1050 kxnord(k_load_dim_mask_extended, k_load_dim_mask_extended,
1051 k_load_dim_mask_extended);
1052 L(update_mask_done);
1053
1054 mov(ptr[rsp + reg_load_loop_work_off], reg_load_loop_work);
1055 compute_diff_bias(load_loop_blk);
1056 bcast_loop(load_loop_blk);
1057 mov(reg_load_loop_work, ptr[rsp + reg_load_loop_work_off]);
1058
1059 add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
1060 switch (jcp.prop_kind) {
1061 case forward_training:
1062 case forward_inference:
1063 add(reg_bias_data,
1064 load_loop_blk * jcp.load_block * jcp.typesize_bia);
1065 add(reg_output_data,
1066 load_loop_blk * jcp.load_block * jcp.typesize_out
1067 * (is_out_layout_nxc()
1068 ? 1
1069 : (jcp.with_dw_conv
1070 ? jcp.ow
1071 : jcp.bcast_dim)));
1072 add(reg_store_buf,
1073 load_loop_blk * jcp.load_block * jcp.typesize_acc
1074 * (is_out_layout_nxc() ? 1 : jcp.bcast_dim));
1075 if (jcp.with_binary) {
1076 const auto oc_off_oprnd = rcx;
1077 mov(oc_off_oprnd,
1078 EVEX_compress_addr(
1079 rsp, reg_binary_post_op_acc_off));
1080 add(oc_off_oprnd, jcp.load_block * load_loop_blk);
1081 mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off),
1082 oc_off_oprnd);
1083 }
1084 break;
1085 case backward_data:
1086 add(reg_output_data,
1087 load_loop_blk * jcp.load_block * jcp.typesize_out
1088 * (is_out_layout_nxc() ? 1 : jcp.bcast_dim));
1089 add(reg_store_buf,
1090 load_loop_blk * jcp.load_block * jcp.typesize_acc
1091 * (is_out_layout_nxc() ? 1 : jcp.bcast_dim));
1092 break;
1093 case backward_weights:
1094 for (int i_load = 0; i_load < load_loop_blk; i_load++)
1095 add(reg_output_data, reg_output_stride);
1096 add(reg_bias_data,
1097 load_loop_blk * jcp.load_block * jcp.typesize_acc);
1098 break;
1099 default: assert(!"invalid prop_kind");
1100 }
1101 sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
1102 };
1103
1104 const int simd_w = 16;
1105
1106 Label load_loop_blk[7];
1107
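    // ur_cases_* hold, for each candidate load_loop_blk (1..6), roughly the
    // largest jcp.ur that still fits the register budget; the comparisons
    // below dispatch the remaining load work to the matching specialization.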
1108 int ur_cases_fma_embd_bcast[] = {2, 4, 5, 8, 14, 32};
1109 int ur_cases_fma_expl_bcast[] = {2, 5, 6, 9, 14, 32};
1110 if (jcp.prop_kind == backward_weights)
1111 for (int i = 1; i < 6; i++)
1112 ur_cases_fma_expl_bcast[i] /= 2;
1113
1114 const int size_ur_cases_fma = jcp.expl_bcast
1115 ? sizeof(ur_cases_fma_expl_bcast)
1116 : sizeof(ur_cases_fma_embd_bcast);
1117 const int *ur_cases_fma = jcp.expl_bcast ? ur_cases_fma_expl_bcast
1118 : ur_cases_fma_embd_bcast;
1119 const int *ur_cases = ur_cases_fma;
1120 const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases);
1121
1122 for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
1123 int label_idx = num_ur_cases - ur_idx - 1;
1124 if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) {
1125 cmp(reg_load_loop_work, simd_w * (label_idx + 1));
1126 jle(load_loop_blk[label_idx], T_NEAR);
1127 }
1128 }
1129
1130 for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
1131 int label_idx = num_ur_cases - ur_idx - 1;
1132 if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) {
1133 L(load_loop_blk[label_idx]);
1134 {
1135 if (label_idx == 0) {
1136 cmp(reg_load_loop_work, 0);
1137 jle(load_loop_blk[num_ur_cases], T_NEAR);
1138 }
1139 load_loop_body(label_idx + 1);
1140 if (label_idx - 1 > 0) {
1141 cmp(reg_load_loop_work, 2 * label_idx * simd_w);
1142 je(load_loop_blk[label_idx - 1], T_NEAR);
1143 }
1144 cmp(reg_load_loop_work, label_idx * simd_w);
1145 jg(load_loop_blk[label_idx]);
1146 }
1147 for (int idx = label_idx - 1; idx >= 0; --idx) {
1148 cmp(reg_load_loop_work, simd_w * (idx + 1));
1149 jge(load_loop_blk[idx], T_NEAR);
1150 }
1151 if (ur_idx < num_ur_cases - 2) {
1152 cmp(reg_load_loop_work, simd_w);
1153 jle(load_loop_blk[0], T_NEAR);
1154 }
1155 }
1156 }
1157 L(load_loop_blk[num_ur_cases]);
1158
1159 add(rsp, stack_space_needed);
1160
1161 postamble();
1162
1163 if (jcp.with_eltwise) postops_injector_->prepare_table();
1164
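    // Permutation table for vpermw in the bwd_w path: it interleaves the
    // lower and upper 16-element halves (0, 16, 1, 17, ...) so that two
    // consecutive bf16 rows end up pairwise interleaved, as vdpbf16ps expects.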
1165 if (jcp.prop_kind == backward_weights) {
1166 const uint16_t dst_prm_array[32] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20,
1167 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
1168 29, 14, 30, 15, 31};
1169
1170 align(64);
1171 L(dst_prm_table);
1172 for (int i = 0; i < 32; ++i)
1173 dw(dst_prm_array[i]);
1174 }
1175}
1176
1177status_t jit_avx512_core_bf16_1x1_conv_kernel::init_conf(
1178 jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
1179 const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
1180 const memory_desc_wrapper &dst_d, primitive_attr_t &attr, int nthreads,
1181 bool reduce_src) {
1182 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1183 const int simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
1184 const int ndims = src_d.ndims();
1185
1186 jcp.nthr = nthreads;
1187 jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
1188 : bf16_emulation_t::get_isa();
1189 jcp.prop_kind = cd.prop_kind;
1190
1191 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1192 jcp.mb = src_d.dims()[0];
1193
1194 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
1195 jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
1196 jcp.ic = src_d.dims()[1] / jcp.ngroups;
1197 jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
1198
1199 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
1200 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
1201 jcp.iw = src_d.dims()[ndims - 1];
1202 jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
1203 jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
1204 jcp.ow = dst_d.dims()[ndims - 1];
1205
1206 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
1207 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
1208 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
1209
1210 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
1211 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
1212 jcp.l_pad = cd.padding[0][ndims - 3];
1213
1214 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
1215 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
1216 jcp.stride_w = cd.strides[ndims - 3];
1217
1218 jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind,
1219 format_kind::undef, cd.diff_bias_desc.format_kind)
1220 != format_kind::undef;
1221
1222 jcp.bia_dt = jcp.with_bias
1223 ? pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.data_type,
1224 data_type::undef, cd.diff_bias_desc.data_type)
1225 : data_type::undef;
1226 jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;
1227
1228 jcp.os = static_cast<dim_t>(jcp.od) * jcp.oh * jcp.ow;
1229 jcp.is = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw;
1230
1231 const auto &post_ops = attr.post_ops_;
1232 const int dw_conv_ind = post_ops.find(primitive_kind::convolution);
1233 jcp.with_dw_conv = dw_conv_ind != -1;
1234 // Using dw_conv_ind as upper-bound below, as post-ops after it will be
1235 // handled in depthwise convolution.
1236 const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind);
1237 jcp.with_sum = sum_ind != -1;
1238
1239 const int eltwise_ind
1240 = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind);
1241 jcp.with_eltwise = eltwise_ind != -1;
1242 if (jcp.with_eltwise) {
1243 if (dst_d.data_type() == data_type::s32) return status::unimplemented;
1244 }
1245 const int binary_ind
1246 = post_ops.find(primitive_kind::binary, 0, dw_conv_ind);
1247 jcp.with_binary = binary_ind != -1;
1248
1249 if (dw_conv_ind >= 0) {
1250 // dw_conv and post_ops after it are handled externally, so skip them
1251 jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(),
1252 post_ops.entry_.cbegin() + dw_conv_ind);
1253 } else {
1254 jcp.post_ops = post_ops;
1255 }
1256
1257 using namespace injector;
1258 static constexpr bool sum_at_pos_0_only = true;
1259 static constexpr bool sum_requires_scale_one = true;
1260 static constexpr bool sum_requires_zp_zero = true;
1261 const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
1262 jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
1263 sum_requires_zp_zero});
1264 if (!post_ops_ok_) return status::unimplemented;
1265
1266 using namespace format_tag;
1267 const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
1268 const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1269 jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
1270 jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
1271 bool is_data_layout_nxc
1272 = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
1273 auto required_dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
1274 bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1
1275 && (src_d.data_type() == data_type::f32
1276 || src_d.data_type() == data_type::bf16);
1277 if (ok_to_pad_channels) {
1278 jcp.oc = rnd_up(jcp.oc, simd_w);
1279 jcp.ic = rnd_up(jcp.ic, simd_w);
1280 }
1281
1282 const int is_bwd_d = jcp.prop_kind == backward_data;
1283 const int is_bwd_w = jcp.prop_kind == backward_weights;
1284
1285 auto wei_tag = utils::pick(
1286 2 * ndims - 6 + with_groups + 6 * is_bwd_d + 12 * is_bwd_w,
1287 OIw8i16o2i, gOIw8i16o2i, OIhw8i16o2i, gOIhw8i16o2i, OIdhw8i16o2i,
1288 gOIdhw8i16o2i, IOw8o16i2o, gIOw8o16i2o, IOhw8o16i2o, gIOhw8o16i2o,
1289 IOdhw8o16i2o, gIOdhw8o16i2o, OIw16i16o, gOIw16i16o, OIhw16i16o,
1290 gOIhw16i16o, OIdhw16i16o, gOIdhw16i16o);
1291 jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
1292
1293 bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == required_dat_tag
1294 && jcp.dst_tag == required_dat_tag && jcp.wei_tag == wei_tag;
1295 if (!args_ok) return status::unimplemented;
1296
1297 args_ok = true
1298 && IMPLICATION(!is_data_layout_nxc,
1299 jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0)
1300 && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0
1301 && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1
1302 && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1 && jcp.ow == jcp.iw
1303 && jcp.oh == jcp.ih && jcp.od == jcp.id; // enforce rpad=0
1304 if (!args_ok) return status::unimplemented;
1305
1306 jcp.ic_block = jcp.oc_block = simd_w;
1307
1308 jcp.typesize_acc = sizeof(float);
1309 if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
1310 jcp.typesize_in = types::data_type_size(src_d.data_type());
1311 jcp.typesize_out = types::data_type_size(dst_d.data_type());
1312 jcp.dst_dt = dst_d.data_type();
1313 } else if (jcp.prop_kind == backward_data) {
1314 jcp.typesize_in = types::data_type_size(dst_d.data_type());
1315 jcp.typesize_out = types::data_type_size(src_d.data_type());
1316 jcp.dst_dt = src_d.data_type();
1317 } else if (jcp.prop_kind == backward_weights) {
1318 jcp.typesize_in = types::data_type_size(src_d.data_type());
1319 jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);
1320 jcp.dst_dt = weights_d.data_type();
1321 }
1322
1323 /* once all the formats are set, check the padding consistency */
1324 args_ok = true && jcp.ic <= src_d.padded_dims()[1]
1325 && jcp.oc <= dst_d.padded_dims()[1]
1326 && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
1327 && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
1328 if (!args_ok) return status::unimplemented;
1329
1330 const int SMALL_SPATIAL = 10;
1331 const int BIG_SPATIAL = 28;
1332
1333 const int BIG_REDUCE_DIM = 1024;
1334 const int BIG_LOAD_DIM = 256;
1335
1336 int load_blocking {0};
1337 int load_blocking_max {0};
1338 int bcast_blocking {0};
1339 int bcast_blocking_max {0};
1340 int reduce_blocking {0};
1341 int reduce_blocking_max {0};
1342
1343 jcp.load_grp_count = 1;
1344
1345 const int L1_capacity
1346 = platform::get_per_core_cache_size(1) / jcp.typesize_in;
1347 const int L2_size = platform::get_per_core_cache_size(2) / jcp.typesize_in;
1348 const int L2_capacity = (L2_size * 3) / 4;
1349
1350 if (one_of(jcp.prop_kind, forward_training, forward_inference,
1351 backward_data)) {
1352 jcp.nthr = nthreads;
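        // Blocking roles for fwd/bwd_d: reduce is the accumulated dimension
        // (ic for fwd, oc for bwd_d), load is the dimension held in the load
        // registers (oc for fwd, ic for bwd_d), and bcast is the spatial
        // dimension broadcast across the FMAs.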
1353 if (one_of(jcp.prop_kind, forward_inference, forward_training)) {
1354 if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur);
1355 jcp.reduce_dim = jcp.ic;
1356 jcp.reduce_block = jcp.ic_block;
1357
1358 jcp.load_dim = jcp.oc;
1359 jcp.load_block = jcp.oc_block;
1360
1361 jcp.bcast_dim = jcp.is;
1362 } else {
1363 jcp.reduce_dim = jcp.oc;
1364 jcp.reduce_block = jcp.oc_block;
1365
1366 jcp.load_dim = jcp.ic;
1367 jcp.load_block = jcp.ic_block;
1368
1369 jcp.bcast_dim = jcp.os;
1370 }
1371 jcp.reduce_loop_unroll = jcp.reduce_block;
1372 jcp.reduce_loop_bcast_step = jcp.typesize_in * jcp.reduce_loop_unroll
1373 * (is_data_layout_nxc ? 1 : jcp.bcast_dim);
1374 jcp.reduce_loop_load_step
1375 = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
1376 jcp.load_loop_load_step
1377 = utils::rnd_up(jcp.reduce_dim, jcp.reduce_block)
1378 * jcp.load_block * jcp.typesize_in;
1379
        // adjusting register blocking
1381 int max_regs, min_regs, size_treshold, ur_step;
1382 const int spatial
1383 = (one_of(jcp.prop_kind, forward_training, forward_inference))
1384 ? jcp.od * jcp.oh
1385 : jcp.id * jcp.ih;
1386 const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block;
1387 if (reduce_dim_tail % 2 == 0 // cannot expl_bcast odd tail
1388 && (8 * jcp.mb) / jcp.nthr >= 1) {
1389 max_regs = 9;
1390 min_regs = 6;
1391 size_treshold = 14;
1392 ur_step = 1;
1393 jcp.expl_bcast = true;
1394
1395 if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM
1396 && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) {
1397 max_regs = 6;
1398 min_regs = isa_has_bf16(jcp.isa) ? 5 : 4;
1399 }
1400 } else {
1401 max_regs = 30;
1402 min_regs = 9;
1403 size_treshold = 14;
1404 ur_step = 1;
1405 jcp.expl_bcast = false;
1406 }
1407 jcp.ur = 1;
1408 if (!isa_has_bf16(jcp.isa)) {
1409 int adj_max_regs = max_regs / 3;
1410 max_regs = (adj_max_regs < min_regs) ? min_regs : adj_max_regs;
1411 }
1412 for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) {
1413 if ((spatial >= size_treshold && spatial % ur_w == 0)
1414 || (spatial < size_treshold && jcp.os % ur_w == 0)) {
1415 jcp.ur = ur_w;
1416 break;
1417 }
1418 }
1419 if (jcp.ur == 1) {
1420 jcp.ur = nstl::min<dim_t>(max_regs, jcp.os);
1421 int os_tail = jcp.os % max_regs;
1422 for (int i = max_regs; i >= min_regs; i -= ur_step) {
1423 int i_tail = jcp.os % i;
1424 if (i_tail > os_tail || i_tail == 0) {
1425 jcp.ur = i;
1426 os_tail = i_tail;
1427 if (i_tail == 0) break;
1428 }
1429 }
1430 }
1431
1432 jcp.bcast_block = jcp.ur;
1433 jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur;
1434
1435 jcp.bcast_loop_output_step
1436 = jcp.ur * (is_data_layout_nxc ? jcp.load_dim : jcp.load_block);
1437 jcp.bcast_loop_output_substep = -1; // unused
1438 jcp.bcast_loop_bcast_step = jcp.typesize_in * jcp.ur
1439 * (is_data_layout_nxc ? jcp.reduce_dim : jcp.reduce_block);
1440 jcp.bcast_loop_bcast_substep = -1; // unused
1441
1442 jcp.load_loop_iter_step = jcp.load_block;
1443
1444 if (jcp.prop_kind == backward_data)
1445 jcp.loop_order = loop_lbr;
1446 else
1447 jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
1448
1449 int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
1450 int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
1451 int nb_load = div_up(jcp.load_dim, jcp.load_block);
1452
1453 if (is_data_layout_nxc
1454 || (jcp.prop_kind == backward_data && reduce_src)) {
1455 reduce_blocking = jcp.reduce_dim;
1456 } else {
1457 if (jcp.expl_bcast) {
1458 if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL
1459 && spatial < BIG_SPATIAL)
1460 reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 160);
1461 else if (spatial > SMALL_SPATIAL)
1462 reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 1024);
1463 else
1464 reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 512);
1465 } else {
1466 reduce_blocking = nb_reduce;
1467 if (spatial <= SMALL_SPATIAL
1468 && jcp.reduce_dim >= BIG_REDUCE_DIM)
1469 reduce_blocking = 32;
1470 else if (spatial > SMALL_SPATIAL
1471 && jcp.reduce_dim >= BIG_REDUCE_DIM)
1472 reduce_blocking = 16;
1473 reduce_blocking
1474 = best_divider(nb_reduce, 1, reduce_blocking, true);
1475 reduce_blocking *= jcp.reduce_block;
1476 }
            // Check for input data cache aliasing.
            // These constants may need updating for other ISAs.
            // 64 * 1024 is chosen based on a 1MB 16-way L2 cache.
            // 7 is an empirical value, about half of 16, so roughly half of
            // the cache set is left for other data: weights and dst.
1482 int way_size = (64 * 1024) / jcp.typesize_in;
1483 int max_hits = 7;
1484 if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) {
1485 int nrb = reduce_blocking / simd_w;
1486 int sp = jcp.bcast_dim;
1487 int wl = way_size / simd_w;
1488 for (int start_off = 0; start_off < jcp.ur; start_off++) {
1489 for (int off = start_off, hits = 0; off < sp * nrb;
1490 off += wl) {
1491 if (off % sp >= jcp.ur || ++hits < max_hits) continue;
1492 int max_r_blocking
1493 = simd_w * nstl::max(1, (off + wl) / sp);
1494 reduce_blocking
1495 = nstl::min(reduce_blocking, max_r_blocking);
1496 break;
1497 }
1498 }
1499 }
1500 }
1501 load_blocking = jcp.load_dim;
1502
1503 int load_size = jcp.load_dim * jcp.reduce_dim;
1504 auto bcast_size
1505 = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;
1506
1507 if (jcp.nthr <= 28 && jcp.mb < jcp.nthr
1508 && nb_load * nb_bcast > jcp.nthr) {
1509 // Some heuristic here
1510 float calc_koef = 0.01, best_cost = FLT_MAX;
1511 int n_lgc = jcp.nthr;
1512 float ratio = (float)load_size / (float)bcast_size;
1513 int best_lgc = ratio > 1 ? n_lgc : 1;
1514 auto calc_job_cost = [&](int lb, int tg, float mem_k) {
1515 int bb_size = jcp.mb * div_up(nb_bcast, tg);
1516 float calc_size = (float)(bb_size * jcp.ur)
1517 * (lb * jcp.load_block) * jcp.reduce_dim;
1518 float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block)
1519 * jcp.reduce_dim;
1520 return calc_koef * calc_size + mem_k * mem_size;
1521 };
1522 for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) {
1523 lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1;
1524 int min_lb = nb_load / lgc;
1525 int max_lb = div_up(nb_load, lgc);
1526 int min_tg = jcp.nthr / lgc;
1527 int max_tg = div_up(jcp.nthr, lgc);
1528 // Some heuristic here
1529 float mem_koef = (max_tg == 1) ? 1.f : 1.3f;
1530 float job_cost = 0.;
1531 if (jcp.nthr % lgc < nb_load % lgc) {
1532 job_cost = calc_job_cost(max_lb, min_tg, mem_koef);
1533 } else {
1534 auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef);
1535 auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef);
1536 job_cost = nstl::max(job_cost1, job_cost2);
1537 }
1538
1539 if (job_cost < best_cost) {
1540 best_lgc = lgc;
1541 best_cost = job_cost;
1542 }
1543 }
1544 jcp.load_grp_count = best_lgc;
1545 load_blocking
1546 = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
1547 } else {
1548 jcp.load_grp_count
1549 = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast);
1550 jcp.load_grp_count = best_divider(jcp.nthr, jcp.load_grp_count,
1551 2 * jcp.load_grp_count, false);
1552 }
1553
1554 if (jcp.expl_bcast && jcp.bcast_dim <= 64 && load_size >= L2_size) {
1555 jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
1556 } else if (jcp.bcast_dim <= 49 && jcp.mb <= jcp.nthr
1557 && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
1558 jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
1559 load_blocking = jcp.load_block;
1560 }
1561
1562 bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
1563 div_up(jcp.nthr, jcp.load_grp_count))
1564 * jcp.bcast_block;
1565 bcast_blocking = nstl::min<dim_t>(jcp.bcast_dim, bcast_blocking);
1566 bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);
1567
1568 int space_for_bcast = (L2_capacity - /* kernel_size - */
1569 2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking
1570 - 3 * 1024);
1571 if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) space_for_bcast /= 2;
1572
1573 int bcast_in_cache
1574 = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
1575 bcast_blocking = nstl::min(
1576 bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
1577
1578 load_blocking_max = load_blocking;
1579 bcast_blocking_max = bcast_blocking * 3 / 2;
1580 reduce_blocking_max = reduce_blocking;
1581 } else if (jcp.prop_kind == backward_weights) {
1582 jcp.use_vmovntps = false;
1583 jcp.reduce_dim = jcp.is;
1584
        // Crossover point for the rn50 v1.5 blocked layout.
        // When is_data_layout_nxc = true,
        // jcp.uses_permw_transposition = false shows better performance for
        // most problems according to performance measurements.
1589 jcp.uses_permw_transposition = !is_data_layout_nxc && jcp.oh > 14
1590 && jcp.ow > 14
1591 // Performance improvement for i3d shapes
1592 && IMPLICATION(ndims == 5, jcp.oc <= 64);
1593
1594 if (jcp.uses_permw_transposition) {
1595 int rdim = nstl::min<dim_t>(256, jcp.reduce_dim);
1596 jcp.reduce_block = best_divider(jcp.reduce_dim, 7, rdim, true, 2);
1597 } else
1598 jcp.reduce_block = best_divider(jcp.reduce_dim, 8, 16, true, 2);
1599
1600 if (jcp.reduce_dim % jcp.reduce_block != 0) {
1601 jcp.reduce_block = best_divider(
1602 jcp.iw, 4, jcp.iw, jcp.uses_permw_transposition, 2);
1603 }
1604 if (jcp.reduce_block > 256) { jcp.reduce_block = 1; }
1605 if (!jcp.uses_permw_transposition)
1606 jcp.reduce_block = rnd_up(jcp.reduce_block, 2);
1607
1608 jcp.load_dim = jcp.oc;
1609 jcp.load_block = jcp.oc_block;
1610
1611 jcp.bcast_dim = jcp.ic;
1612 jcp.bcast_block = jcp.ic_block;
1613
1614 jcp.ur = jcp.bcast_block;
1615 jcp.ur_tail = jcp.bcast_dim % jcp.bcast_block;
1616 // TODO: try to enable jcp.expl_bcast version
1617 jcp.expl_bcast = false;
1618
        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step = jcp.typesize_in * jcp.reduce_loop_unroll
                * (is_data_layout_nxc && jcp.uses_permw_transposition
                                ? jcp.ngroups * jcp.ic
                                : jcp.ic_block);
        jcp.reduce_loop_load_step = jcp.typesize_in * jcp.reduce_loop_unroll
                * (is_data_layout_nxc && jcp.uses_permw_transposition
                                ? jcp.ngroups * jcp.oc
                                : jcp.oc_block);

        jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block;
        jcp.bcast_loop_output_substep
                = jcp.oc_block * jcp.ur * jcp.typesize_out;
        jcp.bcast_loop_bcast_step = jcp.typesize_in * jcp.ic_block
                * (jcp.uses_permw_transposition
                                ? (is_data_layout_nxc ? 1 : jcp.reduce_dim)
                                : rnd_up(jcp.reduce_dim, 2));
        jcp.bcast_loop_bcast_substep = jcp.typesize_in * jcp.ur
                * (jcp.uses_permw_transposition ? 1 : 2);
        jcp.load_loop_load_step = jcp.typesize_in * jcp.oc_block
                * (jcp.uses_permw_transposition
                                ? (is_data_layout_nxc ? 1 : jcp.os)
                                : rnd_up(jcp.reduce_dim, 2));
        jcp.load_loop_iter_step = jcp.oc_block;

        /* --- */
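        // Split jcp.nthr across the minibatch/reduce, load (oc), and bcast
        // (ic) dimensions for the weight reduction.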
        balance(jcp, jcp.nthr);

        load_blocking = div_up(jcp.load_dim, jcp.load_block);
        load_blocking = best_divider(load_blocking, 16, load_blocking, false);
        load_blocking *= jcp.load_block;

        load_blocking_max = load_blocking;

        int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        int min_bcast_blocking = 5;

        bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        bcast_blocking = best_divider(
                bcast_blocking, min_bcast_blocking, max_bcast_blocking, false);
        bcast_blocking *= jcp.bcast_block;
        bcast_blocking_max = bcast_blocking;
        if (!is_data_layout_nxc) {
            assert(jcp.bcast_dim % bcast_blocking == 0);
            assert(jcp.load_dim % load_blocking == 0);
        }

        // for reduction balance
        int max_reduce_blocking
                = nstl::min<dim_t>(L1_capacity / jcp.ur, jcp.reduce_dim);
        int min_reduce_blocking
                = nstl::min(L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih));
        reduce_blocking = best_divider(
                jcp.reduce_dim, min_reduce_blocking, max_reduce_blocking, true);
        reduce_blocking = nstl::max(
                rnd_dn(reduce_blocking, jcp.reduce_block), jcp.reduce_block);

        reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
    } else
        return status::unimplemented;

    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);
    assert(reduce_blocking_max);
    if (!is_data_layout_nxc) {
        assert(load_blocking % jcp.load_block == 0);
        assert(load_blocking_max % jcp.load_block == 0);
    }
    assert(IMPLICATION(jcp.uses_permw_transposition,
            reduce_blocking % jcp.reduce_block == 0));
    assert(IMPLICATION(jcp.uses_permw_transposition,
            reduce_blocking_max % jcp.reduce_block == 0));

    assert(jcp.bcast_block % jcp.ur == 0);
    assert(IMPLICATION(jcp.uses_permw_transposition,
            jcp.reduce_dim % jcp.reduce_block == 0));

    jcp.nb_bcast_blocking = utils::div_up(bcast_blocking, jcp.bcast_block);
    jcp.nb_bcast_blocking_max
            = utils::div_up(bcast_blocking_max, jcp.bcast_block);
    jcp.nb_load_blocking = utils::div_up(load_blocking, jcp.load_block);
    jcp.nb_load_blocking_max = utils::div_up(load_blocking_max, jcp.load_block);
    jcp.nb_reduce_blocking = utils::div_up(reduce_blocking, jcp.reduce_block);
    jcp.nb_reduce_blocking_max
            = utils::div_up(reduce_blocking_max, jcp.reduce_block);

    jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    /* Adjust the thread decomposition to improve performance for small
     * problems: for now, simply cap the thread count at the max of nb_bcast
     * and nb_load.
     * TODO: add a get_thr_eff function to compute the optimal thread count.
     * TODO: the threshold can be increased when the initial stride > 1. */
    auto bcast_size
            = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;
    bool is_adjust_thread = one_of(jcp.prop_kind, forward_training,
                                    forward_inference, backward_data)
            && jcp.typesize_in * bcast_size < 8192 && jcp.ngroups < jcp.nthr
            && jcp.nb_bcast * jcp.nb_load < jcp.nthr;
    if (is_adjust_thread) {
        int nthr = nstl::max(jcp.nb_bcast, jcp.nb_load);
        jcp.nthr = nstl::min(jcp.nthr, nthr);
    }

    return status::success;
}

status_t jit_avx512_core_bf16_1x1_conv_kernel::init_scratchpad(
        memory_tracking::registrar_t &scratchpad,
        const jit_1x1_conv_conf_t &jcp) {
    using namespace dnnl::impl::memory_tracking::names;
    using namespace dnnl::impl::format_tag;

    if (jcp.with_bias && jcp.oc_without_padding % jcp.oc_block
            && utils::one_of(jcp.prop_kind, forward_inference, forward_training,
                    backward_weights)
            // For the nxc layout, bias padding is only needed during the
            // backward-by-weights propagation.
            && IMPLICATION(utils::one_of(jcp.dst_tag, nwc, nhwc, ndhwc),
                    jcp.prop_kind == backward_weights)) {
        scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia);
    }
    if (jcp.prop_kind == backward_weights) {
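        // Weight/bias reduction scratch: one f32 buffer per
        // minibatch-reducing thread, or one fewer when the destination is
        // already f32 (presumably because it can then serve as one of the
        // accumulators itself).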
        const size_t wei_size = (size_t)jcp.ngroups
                * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block);
        const int n_wei_buffers
                = jcp.dst_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1;
        const size_t bias_size = (size_t)jcp.with_bias * jcp.ngroups
                * rnd_up(jcp.oc, jcp.oc_block);
        const int n_bias_buffers = jcp.with_bias
                ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb
                                                 : jcp.nthr_mb - 1)
                : 0;
        const size_t wei_bia_size
                = wei_size * n_wei_buffers + bias_size * n_bias_buffers;
        scratchpad.book(key_conv_wei_reduction, wei_bia_size, jcp.typesize_acc);

        if (!jcp.uses_permw_transposition) {
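            // Without the permw transposition, diff_dst and src are first
            // transposed into per-thread scratch buffers; the reduce
            // dimension is rounded up to an even size for the bf16 kernels.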
            const size_t dst_diff_tr_size_per_thr
                    = (size_t)rnd_up(jcp.reduce_dim, 2) * jcp.oc_block
                    * jcp.nb_load_blocking_max;
            scratchpad.book(key_conv_tr_diff_dst,
                    jcp.nthr * dst_diff_tr_size_per_thr, jcp.typesize_in);
            const size_t src_tr_size_per_thr = (size_t)rnd_up(jcp.reduce_dim, 2)
                    * jcp.ic_block * jcp.nb_bcast_blocking_max;
            scratchpad.book(key_conv_tr_src, jcp.nthr * src_tr_size_per_thr,
                    jcp.typesize_in);
        }
    }

    // TODO: Check - do we need this buffer for ALL cases?
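    // Store (accumulation) buffer for the f32 partial results: per thread,
    // it covers the full rounded-up bcast dimension times the largest load
    // slice a thread may own (the whole load dimension for nxc outputs).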
    if (jcp.prop_kind != backward_weights) {
        const size_t grp_count = utils::div_up(
                jcp.nthr, utils::div_up(jcp.nthr, jcp.load_grp_count));
        const bool is_out_layout_nxc
                = (utils::one_of(jcp.prop_kind, prop_kind::forward_training,
                           prop_kind::forward_inference)
                          && utils::one_of(jcp.dst_tag, format_tag::ndhwc,
                                  format_tag::nhwc, format_tag::nwc))
                || (jcp.prop_kind == prop_kind::backward_data
                        && utils::one_of(jcp.src_tag, format_tag::ndhwc,
                                format_tag::nhwc, format_tag::nwc));
        const size_t max_load_per_thread = is_out_layout_nxc
                ? utils::rnd_up(jcp.load_dim, jcp.load_block)
                : rnd_up((utils::div_up(jcp.load_dim, grp_count)),
                        jcp.load_block);
        const size_t store_buffer_size = (size_t)jcp.nthr
                * utils::rnd_up(jcp.bcast_dim, jcp.bcast_block)
                * max_load_per_thread;
        scratchpad.book(
                key_conv_store_wsp, store_buffer_size, jcp.typesize_acc);
    }
    // Heuristic threshold on the requested scratchpad memory to avoid a
    // possible crash on allocation.
    size_t scratchpad_limit_by_absolute_value
            = (size_t)20 * (1 << 30); // 20 GiB

    // Note: this check is currently skipped for the depthwise-fusion
    // implementation because of the memory requirements of the latter.
    if (!jcp.with_dw_conv
            && scratchpad.size() > scratchpad_limit_by_absolute_value)
        return status::unimplemented;

    return status::success;
}

void jit_avx512_core_bf16_1x1_conv_kernel::balance(
        jit_1x1_conv_conf_t &jcp, int nthreads) {
    // initialize jcp reduction threading properties
    jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1;
    if (nthreads < jcp.ngroups) {
        /* simplification... fortunately it doesn't hurt much */
        return;
    }
    const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    const int nb_load = div_up(jcp.load_dim, jcp.load_block);
    const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    jcp.nthr_g = jcp.ngroups;
    const int nthr = nthreads / jcp.nthr_g;

    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
        /* Calculate the per-thread memory cost (reads/writes). The
         * high-level optimizer tries to minimize memory consumption.
         * A few notes: (n1) it is unclear why, but this essentially helps
         * the first convolution...
         * (n2) assuming the reduction over the minibatch is always there:
         * - instead of 8 the coefficient should be 5 here (a write costs
         *   roughly two reads):
         *   kernel: 1 write to the temporary workspace
         *   reduction: 1 read from the workspace and 1 write to diff_wei
         * - but experiments showed 8 works better than 5 or 6... */
        int bcast_koeff = 1;
        int load_koeff = 1;
        int output_koeff = 12;

        if (jcp.prop_kind == backward_weights) {
            int mult = (jcp.stride_h == 1 && jcp.stride_w == 1)
                    ? nstl::max(1, (jcp.oc / jcp.ic))
                    : 1;
            output_koeff = 4 * mult;
        }

        return 0
                + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_bcast, nthr_ic_b)
                * jcp.ic_block * jcp.reduce_block / jcp.stride_h
                / jcp.stride_w /* (n1) */
                + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
                * jcp.oc_block * jcp.reduce_block
                + (size_t)output_koeff /* (n2) */
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
                * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.oc_block;
    };

    int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1;
    auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);

    /* step 1: find the best thread distribution with the lowest memory cost */
    const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce);
    for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
        const int nthr_par = nthr / nthr_mb;
        const int nthr_oc_b_max = nstl::min(nthr_par, nb_load);
        for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
            nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast);
            auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
            if (mem_cost <= best_mem_cost) {
                best_mem_cost = mem_cost;
                jcp.nthr_mb = nthr_mb;
                jcp.nthr_oc_b = nthr_oc_b;
                jcp.nthr_ic_b = nthr_ic_b;
            }
        }
    }
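    // Heuristic: if the best split assigns more than half (but not all) of
    // the threads to the minibatch/reduce dimension, use min(mb, nthreads)
    // threads there instead.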
    if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads)
        jcp.nthr_mb = nstl::min(jcp.mb, nthreads);

    jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b;
    assert(jcp.nthr <= nthreads);
}
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl