/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <float.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/x64/cpu_barrier.hpp"

#include "cpu/x64/injectors/injector_utils.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp"
#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"

#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
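// GET_OFF(field) resolves to the byte offset of `field` inside
// jit_1x1_conv_call_s, so the generated code can read its runtime arguments
// via ptr[param1 + GET_OFF(...)].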

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::prop_kind;
using namespace dnnl::impl::utils;

using namespace Xbyak;

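// The constructor only wires up the post-ops injector. When eltwise or binary
// post-ops are present, the binary injector is configured with zmm31 as its
// helper register, r14/r15/r12 as scratch GPRs, and k_load_dim_mask for tail
// handling of the last output-channel block.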
jit_avx512_common_1x1_conv_kernel::jit_avx512_common_1x1_conv_kernel(
        const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(jit_name()), jcp(ajcp), attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static constexpr size_t helper_vmm_idx = 31;
        const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
        static constexpr bool use_exact_tail_scalar_bcast = true;

        const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
                r14, r15, r12, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size, k_load_dim_mask,
                use_exact_tail_scalar_bcast};
        const static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, jcp.post_ops, static_params);
    }
}

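// bcast_loop walks the spatial (bcast) dimension in chunks of jcp.bcast_block,
// calling reduce_loop for each jcp.ur-sized substep and advancing the
// input/output pointers between substeps. A tail path handles the remaining
// jcp.ur_tail iterations.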
void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk) {
    mov(aux1_reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
    mov(aux_reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));

    mov(aux_reg_output_data, reg_output_data);
    mov(reg_bcast_loop_iter, EVEX_compress_addr(rsp, reg_bcast_loop_work_offt));

    Label bcast_loop;
    Label bcast_loop_tail;
    Label large_tail;

    cmp(reg_bcast_loop_iter, jcp.bcast_block);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop);
    {
        assert(jcp.bcast_block % jcp.ur == 0);
        int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            if (i + 1 == num_substeps) L(large_tail);
            reduce_loop(load_loop_blk, jcp.ur, i, false);
            if (i < num_substeps - 1) {
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_substep);
            } else {
                add(aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data,
                        jcp.bcast_loop_output_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_output_substep);
            }
            sub(reg_bcast_loop_iter, jcp.ur);
        }
        cmp(reg_bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        if (jcp.ur_tail >= jcp.ur) {
            cmp(reg_bcast_loop_iter, jcp.ur);
            jge(large_tail, T_NEAR);
        }
        if (jcp.ur_tail % jcp.ur) {
            cmp(reg_bcast_loop_iter, 0);
            jle(bcast_loop_tail_out, T_NEAR);
            reduce_loop(load_loop_blk, jcp.ur_tail % jcp.ur, 0, true);
            L(bcast_loop_tail_out);
        }
    }
}

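// output_ptr computes the address of output element (i_load, i_ur). As an
// illustration (example numbers only, not from any particular problem): with
// an nxc destination, jcp.load_dim = 64 and f32 output, moving one spatial
// position (i_ur += 1) advances by 64 * 4 = 256 bytes, while the next 16-wide
// output-channel block (i_load += 1) is only 16 * 4 = 64 bytes away.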
Address jit_avx512_common_1x1_conv_kernel::output_ptr(
        const bool is_out_layout_nxc, const int i_load, const int i_ur) {
    if (one_of(jcp.prop_kind, forward_training, forward_inference,
                backward_data)) {
        auto i_load_shift = is_out_layout_nxc
                ? jcp.load_block
                : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block;
        int i_ur_shift = is_out_layout_nxc ? jcp.load_dim : jcp.load_block;
        auto offset = (i_load * i_load_shift + i_ur * i_ur_shift)
                * jcp.typesize_out;
        return EVEX_compress_addr(aux_reg_output_data, offset);
    } else
        return ptr[aux_reg_output_data
                + (i_load ? reg_output_stride * i_load
                          : 0) // TODO: Xbyak should allow 0 scale
                + jcp.typesize_out * jcp.load_block * i_ur];
}

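// Vector register layout inside reduce_loop: accumulators occupy
// zmm[0 .. ur * load_loop_blk), the per-load-block weight registers follow at
// zmm[ur * load_loop_blk + i_load], and zmm31 serves as the binary injector's
// helper register (helper_vmm_idx above).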
static int vreg_accum_idx(
        const int load_loop_blk, const int i_load, const int i_ur) {
    return (i_ur * load_loop_blk + i_load);
}

template <typename F>
static void iterate(const int load_loop_blk, const int ur, const bool mask_tail,
        const F &fun) {
    for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
        const bool mask_flag = mask_tail && i_load + 1 == load_loop_blk;
        for (int i_ur = 0; i_ur < ur; ++i_ur)
            fun(mask_flag, i_load, i_ur);
    }
}
template <typename F>
static void iterate(const int load_loop_blk, const int ur, const F &fun) {
    iterate(load_loop_blk, ur, false, fun);
}

void jit_avx512_common_1x1_conv_kernel::apply_postops(
        const bool is_out_layout_nxc, const int load_loop_blk, const int ur) {
    injector_utils::vmm_index_set_t vmm_idxs;
    if (jcp.with_binary) {
        binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
        const auto mask_tail = jcp.oc_without_padding % jcp.load_block;
        iterate(load_loop_blk, ur, mask_tail,
                [&](const bool mask_flag, const int i_load, const int i_ur) {
                    const auto vmm_idx
                            = vreg_accum_idx(load_loop_blk, i_load, i_ur);
                    vmm_idxs.emplace(vmm_idx);

                    rhs_arg_params.vmm_idx_to_out_reg.emplace(
                            vmm_idx, aux_reg_output_data);
                    rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx,
                            get_output_offset(is_out_layout_nxc, i_load, i_ur));
                    if (mask_flag)
                        rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
                });

        mov(abi_param1, ptr[rsp + reg_abi_param1_backup]);

        postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
    } else {
        iterate(load_loop_blk, ur,
                [&](const bool, const int i_load, const int i_ur) {
                    vmm_idxs.emplace(
                            vreg_accum_idx(load_loop_blk, i_load, i_ur));
                });
        postops_injector_->compute_vector_range(vmm_idxs);
    }
}

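// reduce_loop computes a ur x load_loop_blk tile of the output: accumulators
// are initialized (from bias or zero), full reduce_loop_unroll FMA blocks run
// over the reduction dimension, a tail block covers reduce_dim % reduce_block,
// and the result is stored, accumulating into the destination for non-first
// reduction chunks (or when a sum post-op is present) and applying
// eltwise/binary post-ops on the last chunk.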
void jit_avx512_common_1x1_conv_kernel::reduce_loop(
        int load_loop_blk, int ur, int substep, bool wraparound) {
    const bool out_layout_nxc = is_out_layout_nxc(jcp);
    const bool load_layout_nxc = is_load_layout_nxc(jcp);
    const bool bcast_layout_nxc = is_bcast_layout_nxc(jcp);
    const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block;
    const int load_dim_tail = jcp.load_dim % jcp.load_block;

    auto vreg_load
            = [=](int i_load) { return Zmm(ur * load_loop_blk + i_load); };

    auto vreg_accum = [=](int i_load, int i_ur) {
        return Zmm(vreg_accum_idx(load_loop_blk, i_load, i_ur));
    };

    auto bias_ptr = [=](int i_load) {
        return EVEX_compress_addr(
                reg_bias_data, jcp.typesize_out * jcp.oc_block * i_load);
    };

    auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
        assert(i_ur < jcp.ur);
        assert(i_reduce <= jcp.reduce_loop_unroll);
        dim_t offt;
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            assert(jcp.reduce_loop_unroll == jcp.reduce_block);
            const int reduce_mul = bcast_layout_nxc ? jcp.reduce_dim
                                                    : jcp.reduce_loop_unroll;
            offt = (i_reduce == jcp.reduce_loop_unroll)
                    ? (jcp.bcast_dim + i_ur) * reduce_mul
                    : i_ur * reduce_mul + i_reduce;
        } else {
            int rmul = bcast_layout_nxc ? jcp.ic : jcp.ic_block;
            offt = i_reduce * rmul + i_ur;
        }
        return EVEX_compress_addr(
                aux_reg_bcast_data, jcp.typesize_in * offt, bcast);
    };

    auto load_ptr = [=](int i_reduce, int i_load) {
        int offt;
        int u0 = i_reduce % jcp.reduce_loop_unroll;
        int u1 = i_reduce / jcp.reduce_loop_unroll;
        int lmul = jcp.load_block
                * (load_layout_nxc ? 1
                                   : utils::rnd_up(
                                           jcp.reduce_dim, jcp.reduce_block));
        int rmul = load_layout_nxc ? jcp.load_dim : jcp.load_block;
        offt = i_load * lmul + u0 * rmul;
        return EVEX_compress_addr(aux_reg_load_data,
                u1 * jcp.reduce_loop_load_step + jcp.typesize_in * offt);
    };
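
    // Illustrative offsets (example numbers only): for the blocked f32 case
    // with load_block = 16 and reduce_dim = 64, lmul = 16 * 64 = 1024, so
    // consecutive i_load blocks of the weights are 1024 floats (4 KB) apart,
    // while consecutive reduction rows within one block (u0) are rmul = 16
    // floats (64 bytes) apart.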

    auto init = [=]() {
        Label init_done;
        Label init_zero;

        if (jcp.with_bias
                && one_of(jcp.prop_kind, forward_training, forward_inference)) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jz(init_zero, T_NEAR);

            for (int i_load = 0; i_load < load_loop_blk; i_load++)
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    auto vreg_acc = vreg_accum(i_load, i_ur);
                    if (i_load + 1 == load_loop_blk && load_dim_tail)
                        vreg_acc = vreg_acc | k_load_dim_mask | T_z;

                    vmovups(vreg_acc, bias_ptr(i_load));
                }
            jmp(init_done, T_NEAR);
        }

        L(init_zero);
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(i_load, i_ur);
                vpxord(r, r, r);
            }
        L(init_done);
    };

    auto store = [=]() {
        Label store_noadd;
        if (!jcp.with_sum) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jnz(store_noadd, T_NEAR);
        }

        for (int i_ur = 0; i_ur < ur; ++i_ur)
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                auto r = vreg_accum(i_load, i_ur);
                if (i_load + 1 == load_loop_blk && load_dim_tail)
                    r = r | k_load_dim_mask | T_z;
                vaddps(r, r, output_ptr(out_layout_nxc, i_load, i_ur));
            }

        L(store_noadd);

        if (jcp.with_eltwise || jcp.with_binary) {
            Label store_nopostops;
            test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
            jz(store_nopostops, T_NEAR);

            apply_postops(out_layout_nxc, load_loop_blk, ur);

            L(store_nopostops);
        }

        auto store_output = [=](bool output_is_aligned) {
            const auto mask_flag = load_dim_tail;
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    auto vreg_acc = vreg_accum(i_load, i_ur);
                    // for nxc_layout-bwd_w, weights are still padded and the
                    // output_ptr here can be uninitialized scratchpad.
                    // To ensure final output (after reduction) is zero-padded,
                    // here we zero-pad output by omitting the mask.
                    if (jcp.prop_kind != backward_weights
                            && (i_load + 1 == load_loop_blk && mask_flag)) {
                        vreg_acc = vreg_acc | k_load_dim_mask;
                    }
                    vmovups(output_ptr(out_layout_nxc, i_load, i_ur), vreg_acc);
                }
            }
        };

        Label unaligned_store, end_store;
        test(aux_reg_output_data, cpu_isa_traits<avx512_core>::vlen - 1);
        jnz(unaligned_store, T_NEAR);
        store_output(true);
        jmp(end_store, T_NEAR);
        L(unaligned_store);
        { store_output(false); }
        L(end_store);
    };

    auto fma_block = [=](bool last_block) {
        const int i_reduce_end = reduce_dim_tail && last_block
                ? reduce_dim_tail
                : jcp.reduce_loop_unroll;

        for (int i_reduce = 0; i_reduce < i_reduce_end; i_reduce++) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                auto vreg = vreg_load(i_load);
                if (i_load + 1 == load_loop_blk && load_dim_tail)
                    vreg = vreg | k_load_dim_mask | T_z;

                vmovups(vreg, load_ptr(i_reduce, i_load));
            }

            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                if (jcp.expl_bcast && load_loop_blk > 1)
                    vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false));
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    auto vreg_acc = vreg_accum(i_load, i_ur);
                    if (i_load + 1 == load_loop_blk && load_dim_tail)
                        vreg_acc = vreg_acc | k_load_dim_mask | T_z;
                    if (jcp.expl_bcast && load_loop_blk > 1)
                        vfmadd231ps(vreg_acc, vreg_load(i_load), vreg_bcast);
                    else
                        vfmadd231ps(vreg_acc, vreg_load(i_load),
                                bcast_ptr(i_reduce, i_ur, true));
                }
            }
        }
    };

    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);

    mov(aux_reg_bcast_data, aux1_reg_bcast_data);
    init();

    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop);
    {
        fma_block(false);
        safe_add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step, reg_long_offt);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    fma_block(true);

    store();
}

void jit_avx512_common_1x1_conv_kernel::generate() {
    preamble();

    sub(rsp, stack_space_needed);
    if (jcp.with_binary) {
        const auto zeroed_reg = r15;
        xor_(zeroed_reg, zeroed_reg);
        mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off), zeroed_reg);
        mov(EVEX_compress_addr(rsp, reg_abi_param1_backup), param1);
    }

    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);

    if (jcp.with_bias) mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(EVEX_compress_addr(rsp, reg_bcast_loop_work_offt), reg_bcast_loop_work);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
    if (jcp.prop_kind == backward_weights)
        mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);

    const int load_dim_tail
            = (one_of(jcp.prop_kind, forward_training, forward_inference)
                              ? jcp.oc_without_padding
                              : jcp.load_dim)
            % jcp.load_block;
    if (load_dim_tail) {
        Reg32 reg_tail_32 = reg_load_dim_tail_mask.cvt32();
        mov(reg_tail_32, (1 << load_dim_tail) - 1);
        kmovw(k_load_dim_tail_mask, reg_tail_32);
    }

    auto load_loop_body = [=](int load_loop_blk) {
        if (load_dim_tail)
            kxnorw(k_load_dim_mask, k_load_dim_mask, k_load_dim_mask);
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
        if (load_dim_tail) {
            Label no_update_mask;
            jge(no_update_mask, T_NEAR);
            kmovw(k_load_dim_mask, k_load_dim_tail_mask);
            L(no_update_mask);
        }
        bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        switch (jcp.prop_kind) {
            case forward_training:
            case forward_inference:
                add(reg_bias_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_out);
                safe_add(reg_output_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_out
                                * (is_out_layout_nxc(jcp)
                                                ? 1
                                                : (jcp.with_dw_conv
                                                                ? jcp.ow
                                                                : jcp.bcast_dim)),
                        reg_long_offt);
                if (jcp.with_binary) {
                    const auto oc_off_oprnd = aux_reg_load_data;
                    mov(oc_off_oprnd,
                            EVEX_compress_addr(
                                    rsp, reg_binary_post_op_acc_off));
                    add(oc_off_oprnd, jcp.load_block * load_loop_blk);
                    mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off),
                            oc_off_oprnd);
                }
                break;
            case backward_data:
                safe_add(reg_output_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_out
                                * (is_out_layout_nxc(jcp) ? 1 : jcp.bcast_dim),
                        reg_long_offt);
                break;
            case backward_weights:
                for (int i_load = 0; i_load < load_loop_blk; i_load++)
                    add(reg_output_data, reg_output_stride);
                break;
            default: assert(!"invalid prop_kind");
        }
    };

    const int simd_w = 16;

    Label load_loop_blk[7];

    // with an implicit load_loop_block {6, 5, 4, 3, 2, 1}
    static const int ur_cases_fma_embd_bcast[] = {2, 4, 5, 8, 14, 32};
    static const int ur_cases_fma_expl_bcast[] = {2, 5, 6, 9, 14, 32};
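
    // Each table entry is (roughly) the largest jcp.ur for which the
    // corresponding load_loop_block {6, 5, 4, 3, 2, 1} still fits in the 32
    // zmm registers: the tile needs ur * load_loop_blk accumulators plus
    // load_loop_blk registers for the loaded weights (and one broadcast
    // register when jcp.expl_bcast is set).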

    const int size_ur_cases_fma = jcp.expl_bcast
            ? sizeof(ur_cases_fma_expl_bcast)
            : sizeof(ur_cases_fma_embd_bcast);

    const int *ur_cases_fma = jcp.expl_bcast ? ur_cases_fma_expl_bcast
                                             : ur_cases_fma_embd_bcast;
    const int *ur_cases = ur_cases_fma;
    const int num_ur_cases = size_ur_cases_fma / sizeof(*ur_cases);

    for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) {
            cmp(reg_load_loop_work, simd_w * (label_idx + 1));
            jle(load_loop_blk[label_idx], T_NEAR);
        }
    }

    for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) {
            L(load_loop_blk[label_idx]);
            {
                if (label_idx == 0) {
                    cmp(reg_load_loop_work, 0);
                    jle(load_loop_blk[num_ur_cases], T_NEAR);
                }
                load_loop_body(label_idx + 1);
                if (label_idx - 1 > 0) {
                    cmp(reg_load_loop_work, 2 * label_idx * simd_w);
                    je(load_loop_blk[label_idx - 1], T_NEAR);
                }
                cmp(reg_load_loop_work, label_idx * simd_w);
                jg(load_loop_blk[label_idx]);
            }
            for (int idx = label_idx - 1; idx >= 0; --idx) {
                cmp(reg_load_loop_work, simd_w * (idx + 1));
                jge(load_loop_blk[idx], T_NEAR);
            }
            if (ur_idx < num_ur_cases - 2) {
                cmp(reg_load_loop_work, simd_w);
                jle(load_loop_blk[0], T_NEAR);
            }
        }
    }
    L(load_loop_blk[num_ur_cases]);

    add(rsp, stack_space_needed);

    postamble();

    if (jcp.with_eltwise) postops_injector_->prepare_table();
}

status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr, int nthreads, bool reduce_src) {
    if (!mayiuse(avx512_core)) return status::unimplemented;

    if (!everyone_is(data_type::f32, src_d.data_type(), weights_d.data_type(),
                dst_d.data_type()))
        return status::unimplemented;

    jcp.nthr = nthreads;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    const int simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
    const int ndims = src_d.ndims();

    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc = jcp.oc_without_padding;
    jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
    jcp.ic = jcp.ic_without_padding;

    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];

    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind,
                            format_kind::undef, cd.diff_bias_desc.format_kind)
            != format_kind::undef;

    jcp.os = static_cast<dim_t>(jcp.od) * jcp.oh * jcp.ow;
    jcp.is = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw;

    const auto &post_ops = attr.post_ops_;
    const int dw_conv_ind = post_ops.find(primitive_kind::convolution);
    jcp.with_dw_conv = dw_conv_ind != -1;
    // Using dw_conv_ind as upper-bound below, as post-ops after it will be
    // handled in depthwise convolution.
    const int eltwise_ind
            = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) {
        if (dst_d.data_type() == data_type::s32) return status::unimplemented;
    }

    const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind);
    jcp.with_sum = sum_ind != -1;

    const int binary_ind
            = post_ops.find(primitive_kind::binary, 0, dw_conv_ind);
    jcp.with_binary = binary_ind != -1;

    if (dw_conv_ind >= 0) {
        // dw_conv and post_ops after it are handled externally, so skip them
        jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(),
                post_ops.entry_.cbegin() + dw_conv_ind);
    } else {
        jcp.post_ops = post_ops;
    }

    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
    bool is_data_layout_nxc
            = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
    auto required_dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;

    bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1
            && src_d.data_type() == data_type::f32;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    using namespace injector;
    static constexpr bool sum_at_pos_0_only = true;
    static constexpr bool sum_requires_scale_one = true;
    static constexpr bool sum_requires_zp_zero = true;
    const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
            sum_requires_zp_zero});
    if (!post_ops_ok_) return status::unimplemented;

    bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == required_dat_tag
            && jcp.dst_tag == required_dat_tag
            && IMPLICATION(!is_data_layout_nxc,
                    jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0)
            && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0
            && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1
            && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1 && jcp.ow == jcp.iw
            && jcp.oh == jcp.ih && jcp.od == jcp.id; // enforce rpad=0
    if (!args_ok) return status::unimplemented;

    jcp.ic_block = jcp.oc_block = simd_w;

    const int is_bwd_d = jcp.prop_kind == backward_data;
    format_tag_t wei_tag = with_groups
            ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i,
                    gOIhw16i16o, gIOhw16o16i, gOIdhw16i16o, gIOdhw16o16i)
            : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i, OIhw16i16o,
                    IOhw16o16i, OIdhw16i16o, IOdhw16o16i);

    jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
    if (jcp.wei_tag != wei_tag) return status::unimplemented;

    jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type);
    jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);

    /* once all the formats are set, check the padding consistency */
    if (!is_data_layout_nxc) {
        args_ok = true && jcp.ic <= src_d.padded_dims()[1]
                && jcp.oc <= dst_d.padded_dims()[1]
                && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
                && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
        if (!args_ok) return status::unimplemented;
    }

    const int SMALL_SPATIAL = 10;
    const int BIG_SPATIAL = 28;
    const int BIG_REDUCE_DIM = 1024;
    const int BIG_LOAD_DIM = 256;

    int load_blocking {0};
    int load_blocking_max {0};
    int bcast_blocking {0};
    int bcast_blocking_max {0};
    int reduce_blocking {0};
    int reduce_blocking_max {0};

    jcp.load_grp_count = 1;

    const int L1_capacity
            = platform::get_per_core_cache_size(1) / sizeof(float);
    const int L2_size = platform::get_per_core_cache_size(2) / sizeof(float);
    const int L2_capacity = (L2_size * 3) / 4;

    if (one_of(jcp.prop_kind, forward_training, forward_inference,
                backward_data)) {
        if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
            if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur);
            jcp.reduce_dim = jcp.ic;
            jcp.reduce_block = jcp.ic_block;

            jcp.load_dim = jcp.oc;
            jcp.load_block = jcp.oc_block;

            jcp.bcast_dim = jcp.is;
        } else {
            jcp.reduce_dim = jcp.oc;
            jcp.reduce_block = jcp.oc_block;

            jcp.load_dim = jcp.ic;
            jcp.load_block = jcp.ic_block;

            jcp.bcast_dim = jcp.os;
        }
        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll
                * (is_data_layout_nxc ? 1 : jcp.bcast_dim) * jcp.typesize_in;

        jcp.reduce_loop_load_step
                = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
        jcp.load_loop_load_step
                = (utils::rnd_up(jcp.reduce_dim, jcp.reduce_block))
                * jcp.load_block * jcp.typesize_in;

        // adjusting register blocking
        int max_regs, min_regs, size_threshold;
        const int spatial
                = (one_of(jcp.prop_kind, forward_training, forward_inference))
                ? jcp.od * jcp.oh
                : jcp.id * jcp.ih;
        if ((8 * jcp.mb) / jcp.nthr >= 1
                // NHWC perf: RN50 mb=1
                || (is_data_layout_nxc && jcp.mb == 1)) {
            max_regs = 9;
            min_regs = 6;
            size_threshold = 14;
            jcp.expl_bcast = true;

            if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM
                    && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) {
                max_regs = 6;
                min_regs = 5;
            }
        } else {
            max_regs = 30;
            min_regs = 9;
            size_threshold = 14;
            jcp.expl_bcast = false;
            jcp.use_vmovntps = true;
        }
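        // The loops below pick jcp.ur within [min_regs, max_regs]: prefer a
        // value that divides the spatial size (or jcp.os for small spatial)
        // evenly; otherwise fall back to the candidate whose remainder block
        // is fullest, taking an exact divisor of jcp.os if one exists.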
        jcp.ur = 1;
        for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
            if ((spatial >= size_threshold && spatial % ur_w == 0)
                    || (spatial < size_threshold && jcp.os % ur_w == 0)) {
                jcp.ur = ur_w;
                break;
            }
        }
        if (jcp.ur == 1) {
            jcp.ur = nstl::min<dim_t>(max_regs, jcp.os);
            int os_tail = jcp.os % max_regs;
            for (int i = max_regs; i >= min_regs; i--) {
                int i_tail = jcp.os % i;
                if (i_tail > os_tail || i_tail == 0) {
                    jcp.ur = i;
                    os_tail = i_tail;
                    if (i_tail == 0) break;
                }
            }
        }
        jcp.bcast_block = jcp.ur;

        jcp.bcast_loop_output_step = jcp.ur * jcp.typesize_out
                * (is_data_layout_nxc ? jcp.load_dim : jcp.load_block);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur * jcp.typesize_in
                * (is_data_layout_nxc ? jcp.reduce_dim : jcp.reduce_block);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_iter_step = jcp.load_block;

        if (jcp.prop_kind == backward_data)
            jcp.loop_order = loop_lbr;
        else
            jcp.loop_order = reduce_src ? loop_blr : loop_lbr;

        int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
        int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
        int nb_load = div_up(jcp.load_dim, jcp.load_block);
        if (is_data_layout_nxc) {
            reduce_blocking = jcp.reduce_dim;
        } else if (jcp.expl_bcast) {
            if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL
                    && spatial < BIG_SPATIAL)
                reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 80);
            else if (spatial > SMALL_SPATIAL)
                reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 512);
            else
                reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 256);
        } else {
            reduce_blocking = nb_reduce;
            if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
                reduce_blocking = 16;
            else if (spatial > SMALL_SPATIAL
                    && jcp.reduce_dim >= BIG_REDUCE_DIM)
                reduce_blocking = 8;
            reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
            reduce_blocking *= jcp.reduce_block;
        }

        // Check input data cache aliasing.
        // For other ISAs the constants may need updating.
        // 64 * 1024 is chosen due to the 1MB 16-way L2 cache.
        // 7 is an empirical value, about half of 16, so we leave roughly half
        // of the cache sets for other data - weights, dst.
        int way_size = (64 * 1024) / jcp.typesize_in;
        int max_hits = 7;
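        // For f32 (typesize_in = 4) this gives way_size = 16 * 1024 floats,
        // i.e. one 64 KB "column" of a 16-way 1 MB L2. Roughly, the loop below
        // shrinks reduce_blocking when too many (max_hits) aliasing addresses
        // of the bcast data fall inside the current ur window.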
        if (!is_data_layout_nxc
                && jcp.bcast_dim * reduce_blocking
                        > static_cast<dim_t>(way_size) * max_hits) {
            int nrb = reduce_blocking / simd_w;
            auto sp = jcp.bcast_dim;
            int wl = way_size / simd_w;
            for (int start_off = 0; start_off < jcp.ur; start_off++) {
                for (dim_t off = start_off, hits = 0; off < sp * nrb;
                        off += wl) {
                    if (off % sp >= jcp.ur || ++hits < max_hits) continue;
                    int max_r_blocking
                            = simd_w * nstl::max<dim_t>(1, (off + wl) / sp);
                    reduce_blocking
                            = nstl::min(reduce_blocking, max_r_blocking);
                    break;
                }
            }
        }

        if (reduce_blocking < jcp.reduce_dim) {
            if (jcp.prop_kind == backward_data)
                jcp.loop_order = reduce_src ? loop_lbr : loop_rlb;
            else
                jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
        }
        load_blocking = jcp.load_dim;

        int load_size = jcp.load_dim * jcp.reduce_dim;
        auto bcast_size
                = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;

        if (jcp.nthr <= 28 && jcp.mb < jcp.nthr
                && nb_load * nb_bcast > jcp.nthr) {
            // Some heuristic here
            float calc_koef = 0.01, best_cost = FLT_MAX;
            int n_lgc = jcp.nthr;
            float ratio = (float)load_size / (float)bcast_size;
            int best_lgc = ratio > 1 ? n_lgc : 1;
            auto calc_job_cost = [&](int lb, int tg, float mem_k) {
                int bb_size = jcp.mb * div_up(nb_bcast, tg);
                float calc_size = (float)(bb_size * jcp.ur)
                        * (lb * jcp.load_block) * jcp.reduce_dim;
                float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block)
                        * jcp.reduce_dim;
                return calc_koef * calc_size + mem_k * mem_size;
            };
            for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) {
                lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1;
                int min_lb = nb_load / lgc;
                int max_lb = div_up(nb_load, lgc);
                int min_tg = jcp.nthr / lgc;
                int max_tg = div_up(jcp.nthr, lgc);
                // Some heuristic here
                float mem_koef = (max_tg == 1) ? 1.f : 1.3f;
                float job_cost = 0.;
                if (jcp.nthr % lgc < nb_load % lgc) {
                    job_cost = calc_job_cost(max_lb, min_tg, mem_koef);
                } else {
                    auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef);
                    auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef);
                    job_cost = nstl::max(job_cost1, job_cost2);
                }

                if (job_cost < best_cost) {
                    best_lgc = lgc;
                    best_cost = job_cost;
                }
            }
            jcp.load_grp_count = best_lgc;
            load_blocking
                    = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
        } else {
            jcp.load_grp_count
                    = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast);
            jcp.load_grp_count = best_divider(jcp.nthr, jcp.load_grp_count,
                    2 * jcp.load_grp_count, false);
        }

        if (jcp.expl_bcast && jcp.bcast_dim <= 64 && load_size >= L2_size) {
            jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
        } else if (jcp.bcast_dim <= 49 && jcp.mb <= jcp.nthr
                && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
            jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
            load_blocking = jcp.load_block;
        }

        auto get_thr_eff = [=](int load_chunk, int nthr) {
            int lgc = div_up(nb_load, load_chunk);
            int thr_per_grp = div_up(nthr, lgc);
            int bcast_per_thr
                    = div_up(jcp.mb * nb_bcast, thr_per_grp) * jcp.bcast_block;
            int load_per_thr = load_chunk * simd_w;
            float data_norm = (bcast_per_thr + load_per_thr) / 2.f;
            float data_eff
                    = (bcast_per_thr * load_per_thr) / (data_norm * data_norm);
            float thr_eff_over_grp
                    = (float)nstl::max(1, nthr / lgc) / div_up(nthr, lgc);
            float thr_eff_in_grp = ((float)jcp.mb * nb_bcast)
                    / rnd_up(jcp.mb * nb_bcast, thr_per_grp);
            float thr_eff = thr_eff_over_grp * thr_eff_in_grp;
            float load_eff = (float)nb_load / rnd_up(nb_load, lgc);
            float overall_eff = data_eff + thr_eff + load_eff;
            return overall_eff;
        };

        auto get_load_chunk = [=](int nthr) {
            float best_eff = -1.0f;
            int best_lgc = 1;
            float eff;

            for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) {
                int lgc = div_up(nb_load, load_chunk);
                if (lgc > nthr) continue;
                eff = get_thr_eff(load_chunk, nthr);
                if (eff > best_eff) {
                    best_eff = eff;
                    best_lgc = lgc;
                }
            }
            return best_lgc;
        };

        /* adjust the thread decomposition
         * to improve thr_eff for small problem sizes;
         * the threshold 8192 is empirical
         * TODO: the threshold can be increased for init stride > 1 */
        if (sizeof(float) * bcast_size < 8192 && jcp.mb < jcp.nthr
                && nb_load * nb_bcast < jcp.nthr) {
            float best_thr_eff = -1.0f;
            float thr_eff = -1.0f;
            int overall_lgc = jcp.load_grp_count;
            int lgc = 1;
            int best_nthr = jcp.nthr;
            int end_nthr = with_groups ? jcp.ngroups : 1;
            for (int nthr = jcp.nthr / 2; nthr >= end_nthr; nthr--) {
                lgc = get_load_chunk(nthr);
                thr_eff = get_thr_eff(lgc, nthr);
                if (best_thr_eff < thr_eff) {
                    best_thr_eff = thr_eff;
                    overall_lgc = lgc;
                    best_nthr = nthr;
                }
            }
            jcp.nthr = best_nthr;
            jcp.load_grp_count = overall_lgc;
            load_blocking
                    = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
        }

        bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
                                 div_up(jcp.nthr, jcp.load_grp_count))
                * jcp.bcast_block;
        bcast_blocking = nstl::min<dim_t>(jcp.bcast_dim, bcast_blocking);
        bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);

        int space_for_bcast = (L2_capacity - /* kernel_size - */
                2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking
                - 3 * 1024);
        if (jcp.reduce_dim * jcp.bcast_dim > static_cast<dim_t>(L2_capacity))
            space_for_bcast /= 2;

        int bcast_in_cache
                = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
        bcast_blocking = nstl::min(
                bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
        // NHWC perf
        if (is_data_layout_nxc) bcast_blocking = jcp.bcast_block;

        load_blocking_max = load_blocking;
        bcast_blocking_max = bcast_blocking * 3 / 2;
        reduce_blocking_max = reduce_blocking;

        jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur;

    } else if (jcp.prop_kind == backward_weights) {
        jcp.reduce_dim = jcp.is;

        jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true);
        if (jcp.reduce_dim % jcp.reduce_block != 0)
            jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false);
        if (jcp.reduce_block > 256) { jcp.reduce_block = 1; }

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.ic;
        jcp.bcast_block = jcp.ic_block;

        if (jcp.reduce_block <= 19 &&
                // maskrcnn optimization for nxc; don't reduce ur when ocb<=1
                !(is_data_layout_nxc && jcp.load_dim <= jcp.load_block)) {
            // if reduce_block is big then generated JIT code may be big
            // for small values of ur because reduce_loop_unroll = reduce_block
            jcp.ur = jcp.bcast_block / 2;
            jcp.expl_bcast = true;
        } else {
            jcp.ur = jcp.bcast_block;
            jcp.expl_bcast = false;
        }

        jcp.ur_tail = jcp.bcast_dim % jcp.bcast_block;
        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step = static_cast<dim_t>(jcp.typesize_in)
                * jcp.reduce_loop_unroll
                * (is_data_layout_nxc ? jcp.ic : jcp.ic_block);
        jcp.reduce_loop_load_step = jcp.typesize_in * jcp.reduce_loop_unroll
                * (is_data_layout_nxc ? jcp.oc : jcp.oc_block);

        jcp.bcast_loop_output_step
                = jcp.oc_block * jcp.ic_block * jcp.typesize_out;
        jcp.bcast_loop_output_substep
                = jcp.oc_block * jcp.ur * jcp.typesize_out;
        jcp.bcast_loop_bcast_step = jcp.ic_block
                * (is_data_layout_nxc ? 1
                                      : utils::rnd_up(
                                              jcp.reduce_dim, jcp.reduce_block))
                * jcp.typesize_in;
        jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in;

        jcp.load_loop_load_step = jcp.typesize_in * jcp.oc_block
                * (is_data_layout_nxc ? 1 : jcp.os);
        jcp.load_loop_iter_step = jcp.oc_block;

        /* --- */
        balance(jcp);

        load_blocking = div_up(jcp.load_dim, jcp.load_block);
        load_blocking = best_divider(load_blocking, 16, load_blocking, false);
        load_blocking *= jcp.load_block;

        load_blocking_max = load_blocking;
        assert(IMPLICATION(
                !is_data_layout_nxc, jcp.load_dim % load_blocking == 0));

        int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        int min_bcast_blocking = 5;

        bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        bcast_blocking = best_divider(
                bcast_blocking, min_bcast_blocking, max_bcast_blocking, false);
        bcast_blocking *= jcp.bcast_block;
        bcast_blocking_max = bcast_blocking;
        assert(IMPLICATION(
                !is_data_layout_nxc, jcp.bcast_dim % bcast_blocking == 0));

        // for reduction balance
        if (is_data_layout_nxc && jcp.reduce_dim >= BIG_SPATIAL * BIG_SPATIAL
                && jcp.load_dim >= BIG_LOAD_DIM / 2) {
            reduce_blocking = rnd_up(nstl::min(jcp.ow, 256), jcp.reduce_block);
        } else {
            int max_reduce_blocking
                    = nstl::min<dim_t>(L1_capacity / jcp.ur, jcp.reduce_dim);
            int min_reduce_blocking = nstl::min(
                    L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih));
            reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking,
                    max_reduce_blocking, true);
            reduce_blocking
                    = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block),
                            jcp.reduce_block);
        }

        reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
    } else
        return status::unimplemented;

    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);
    assert(reduce_blocking_max);

    if (!is_data_layout_nxc) {
        assert(load_blocking % jcp.load_block == 0);
        assert(reduce_blocking % jcp.reduce_block == 0);
        assert(load_blocking_max % jcp.load_block == 0);
        assert(reduce_blocking_max % jcp.reduce_block == 0);
        assert(jcp.reduce_dim % jcp.reduce_block == 0);
    }

    assert(jcp.bcast_block % jcp.ur == 0);

    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = utils::div_up(load_blocking, jcp.load_block);
    jcp.nb_load_blocking_max = utils::div_up(load_blocking_max, jcp.load_block);
    jcp.nb_reduce_blocking = utils::div_up(reduce_blocking, jcp.reduce_block);
    jcp.nb_reduce_blocking_max
            = utils::div_up(reduce_blocking_max, jcp.reduce_block);

    jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    return status::success;
}

void jit_avx512_common_1x1_conv_kernel::init_scratchpad(
        memory_tracking::registrar_t &scratchpad,
        const jit_1x1_conv_conf_t &jcp) {
    using namespace dnnl::impl::memory_tracking::names;

    // For nxc layout, bias is padded only for the bwd_wb direction, as bias
    // reduction kernels can't handle tails yet.
    if (jcp.with_bias && jcp.prop_kind != backward_data
            && (jcp.oc != jcp.oc_without_padding // blocked layout
                    || (jcp.prop_kind == backward_weights // nxc layout
                            && jcp.oc % jcp.oc_block != 0))) {

        const size_t nelems_padded_bias
                = jcp.ngroups * utils::rnd_up(jcp.oc, jcp.oc_block);
        scratchpad.book(
                key_conv_padded_bias, nelems_padded_bias, jcp.typesize_out);
    }

    if (jcp.prop_kind == backward_weights) {
        const size_t wei_size = (size_t)jcp.ngroups
                * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block);
        scratchpad.book(key_conv_wei_reduction, wei_size * (jcp.nthr_mb - 1),
                jcp.typesize_out);
    }
}

void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp) {
    int nthreads = jcp.nthr;
    // initialize jcp reduction threading properties
    jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1;
    if (nthreads < jcp.ngroups) {
        /* simplification... fortunately it doesn't hurt much */
        return;
    }
    const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    const int nb_load = div_up(jcp.load_dim, jcp.load_block);
    const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    jcp.nthr_g = jcp.ngroups;
    const int nthr = nthreads / jcp.nthr_g;

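    // The remaining nthr threads are split across three dimensions:
    // minibatch * reduction chunks (nthr_mb), output-channel blocks
    // (nthr_oc_b) and input-channel blocks (nthr_ic_b), by exhaustively
    // minimizing the per-thread memory traffic estimated by calc_mem_cost.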
    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
        /* calculate per thread memory cost (read/write). high level
         * optimizer tries to minimize memory consumption. few notes: (n1)
         * unclear why, but that essentially helps first convolution...
         * (n2) assuming the reduction over minibatch is always there:
         * - instead of 8 it should be 5 here (write ~= 2 read):
         *   kernel: temporal workspace 1 write
         *   reduction: 1 read from workspace and 1 write to the diff_wei
         * - but experiments showed 8 works better than 5 or 6... */
        int bcast_koeff = 1;
        int load_koeff = 1;
        int output_koeff = 12;
        return 0
                + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_bcast, nthr_ic_b)
                * jcp.ic_block * jcp.reduce_block / jcp.stride_h
                / jcp.stride_w /* (n1) */
                + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
                * jcp.oc_block * jcp.reduce_block
                + (size_t)output_koeff /* (n2) */
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
                * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.oc_block;
    };

    int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1;
    auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);

    /* step 1: find the best thread distribution with lowest memory cost */
    const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce);
    for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
        const int nthr_par = nthr / nthr_mb;
        const int nthr_oc_b_max = nstl::min(nthr_par, nb_load);
        for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
            nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast);
            auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
            if (mem_cost <= best_mem_cost) {
                best_mem_cost = mem_cost;
                jcp.nthr_mb = nthr_mb;
                jcp.nthr_oc_b = nthr_oc_b;
                jcp.nthr_ic_b = nthr_ic_b;
            }
        }
    }
    if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads)
        jcp.nthr_mb = nstl::min(jcp.mb, nthreads);

    jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b;
    assert(jcp.nthr <= nthreads);
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl