/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/memory.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_sse41_1x1_conv_kernel_f32.hpp"
#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"

#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::prop_kind;
using namespace dnnl::impl::utils;

using namespace Xbyak;

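// Kernel constructor: when eltwise or binary post-ops are present, create the
// post-ops injector used by apply_postops(). The binary injector is given
// xmm15 as its helper register, r13-r15 as scratch GPRs, and the
// output-channel tail size.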
jit_sse41_1x1_conv_kernel_f32::jit_sse41_1x1_conv_kernel_f32(
        const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, sse41)
    , jcp(ajcp)
    , attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary) {
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static constexpr size_t helper_vmm_idx = 15;
        const size_t tail_size = jcp.oc_without_padding % simd_w_;
        static constexpr bool use_exact_tail_scalar_bcast = false;

        const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
                helper_vmm_idx, r13, r14, r15, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size,
                use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t static_params {
                this->param1, rhs_arg_static_params};
        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<sse41>>(
                this, jcp.post_ops, static_params);
    }
}

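// Broadcast (spatial) loop: walks the bcast dimension in steps of
// jcp.bcast_block, emitting one reduce loop per substep of jcp.ur rows, then
// handles a remainder of jcp.ur_tail rows if present.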
void jit_sse41_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) {
    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_output_data, reg_output_data);
    mov(bcast_loop_iter, reg_bcast_loop_work);

    Label bcast_loop;
    Label bcast_loop_tail;

    cmp(bcast_loop_iter, jcp.ur);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop);
    {
        assert(jcp.bcast_block % jcp.ur == 0);
        int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            generate_reduce_loop(load_loop_blk, jcp.ur);
            if (i < num_substeps - 1) {
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_substep);
            } else {
                add(aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data,
                        jcp.bcast_loop_output_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_output_substep);
            }
        }
        sub(bcast_loop_iter, jcp.bcast_block);
        cmp(bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        cmp(bcast_loop_iter, 0);
        jz(bcast_loop_tail_out, T_NEAR);
        generate_reduce_loop(load_loop_blk, jcp.ur_tail);
        L(bcast_loop_tail_out);
    }
}

size_t jit_sse41_1x1_conv_kernel_f32::get_fwd_output_ptr_l_off(
        int i, int j, int n) const {
    return i * get_output_i_offset(jcp) + j * get_output_j_offset(jcp) + n * 4;
}

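// Index of the accumulator register for output sub-block (i, j, n): each
// 8-wide output channel block is handled as two 4-float xmm halves selected
// by n.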
static int reg_accum_idx(
        const int load_loop_blk, const int i, const int j, const int n) {
    return 2 * j * load_loop_blk + 2 * i + n + 1;
}

template <typename F>
static void iterate(const int load_loop_blk, const int ur, const F &f) {
    for (int j = 0; j < ur; ++j)
        for (int i = 0; i < load_loop_blk; ++i)
            for (int n = 0; n < 2; n++)
                f(i, j, n);
}
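// Applies the eltwise/binary post-ops to all live accumulator registers. For
// binary post-ops, abi_param1 (restored from the stack backup) is passed to
// the injector together with per-register output offsets so broadcast
// operands can be addressed relative to the destination.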
void jit_sse41_1x1_conv_kernel_f32::apply_postops(
        const int load_loop_blk, const int ur) {
    injector_utils::vmm_index_set_t vmm_idxs;
    if (jcp.with_binary) {
        binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
        iterate(load_loop_blk, ur, [&](const int i, const int j, const int n) {
            const bool mask_flag = (2 * i + n) == load_loop_blk - 1;
            const size_t aux_output_offset
                    = get_fwd_output_ptr_l_off(i, j, n) * sizeof(float);
            const auto vmm_idx = reg_accum_idx(load_loop_blk, i, j, n);
            vmm_idxs.emplace(vmm_idx);

            rhs_arg_params.vmm_idx_to_out_reg.emplace(
                    vmm_idx, aux_reg_output_data);
            rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                    vmm_idx, aux_output_offset);
            if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
        });
        const injector_utils::register_preserve_guard_t register_guard(
                this, {abi_param1});
        const size_t reg_guard_stack_occupied
                = register_guard.stack_space_occupied();
        mov(abi_param1,
                ptr[rsp + reg_abi_param1_backup + reg_guard_stack_occupied]);

        postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
    } else {
        iterate(load_loop_blk, ur, [&](const int i, const int j, const int n) {
            vmm_idxs.emplace(reg_accum_idx(load_loop_blk, i, j, n));
        });
        postops_injector_->compute_vector_range(vmm_idxs);
    }
}

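// Reduce loop: accumulates the dot product over the reduce dimension (input
// channels for forward propagation) for a ur x load_loop_blk tile of outputs.
// init() preloads bias or zeroes the accumulators, fma_block() does the
// multiply-accumulate work, and store() merges previously written partial
// results, applies post-ops on the last reduction chunk and writes the tile
// back to memory.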
void jit_sse41_1x1_conv_kernel_f32::generate_reduce_loop(
        int load_loop_blk, int ur) {
    auto reg_load = [=](int i, int n) {
        return Xmm(2 * ur * load_loop_blk + 2 * i + n + 1);
    };

    auto reg_accum = [=](int i, int j, int n) {
        return Xmm(reg_accum_idx(load_loop_blk, i, j, n));
    };

    auto bias_ptr = [=](int i, int n) {
        return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i
                + n * 4 * sizeof(float)];
    };

    auto bcast_ptr = [=](int u, int j) {
        assert(j < jcp.ur);
        assert(u <= jcp.reduce_loop_unroll);
        size_t offt;
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            assert(jcp.reduce_loop_unroll == jcp.reduce_block);
            offt = get_bcast_offset(jcp, u, j);
        } else
            offt = u * jcp.ic_block + j;
        return ptr[aux_reg_bcast_data + offt];
    };

    auto load_ptr = [=](int u, int i, int n) {
        size_t offt;
        size_t u0 = u % jcp.reduce_loop_unroll;
        size_t u1 = u / jcp.reduce_loop_unroll;
        switch (jcp.prop_kind) {
            case backward_data:
                offt = (i * jcp.oc_block + u0) * jcp.ic_block;
                break;
            case backward_weights:
                offt = (i * jcp.os + u0) * jcp.oc_block;
                break;
            default: offt = (i * jcp.ic + u0) * jcp.oc_block;
        }
        return ptr[aux_reg_load_data + u1 * jcp.reduce_loop_load_step
                + sizeof(float) * offt + n * 4 * sizeof(float)];
    };

    auto output_ptr = [=](int i, int j, int n) {
        switch (jcp.prop_kind) {
            case backward_data:
                return ptr[aux_reg_output_data
                        + (i * jcp.is + j) * jcp.ic_block * sizeof(float)
                        + n * 4 * sizeof(float)];
            case backward_weights:
                return ptr[aux_reg_output_data
                        + (i ? reg_output_stride * i
                             : 0) // TODO: Xbyak should allow 0 scale
                        + sizeof(float) * jcp.oc_block * j
                        + n * 4 * sizeof(float)];
            default:
                return ptr[aux_reg_output_data
                        + get_fwd_output_ptr_l_off(i, j, n) * sizeof(float)];
        }
    };

    auto init = [=]() {
        Label init_done;
        Label init_zero;

        if (jcp.with_bias
                && one_of(jcp.prop_kind, forward_training, forward_inference)) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jz(init_zero);

            for (int i = 0; i < load_loop_blk; i++)
                for (int j = 0; j < ur; ++j) {
                    movups(reg_accum(i, j, 0), bias_ptr(i, 0));
                    movups(reg_accum(i, j, 1), bias_ptr(i, 1));
                }
            jmp(init_done);
        }

        L(init_zero);
        for (int i = 0; i < load_loop_blk; ++i)
            for (int j = 0; j < ur; ++j) {
                auto r0 = reg_accum(i, j, 0);
                auto r1 = reg_accum(i, j, 1);
                xorps(r0, r0);
                xorps(r1, r1);
            }

        L(init_done);

        // load weights
        for (int i = 0; i < load_loop_blk; ++i) {
            movups(reg_load(i, 0), load_ptr(0, i, 0));
            movups(reg_load(i, 1), load_ptr(0, i, 1));
        }

        movss(reg_bcast, bcast_ptr(0, 0));
        shufps(reg_bcast, reg_bcast, 0);
    }; // init()

    auto store = [=]() {
        Label store_noadd;

        if (!jcp.with_sum) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jnz(store_noadd, T_NEAR);
        }

        for (int j = 0; j < ur; ++j)
            for (int i = 0; i < load_loop_blk; ++i) {
                auto r0 = reg_accum(i, j, 0);
                auto r1 = reg_accum(i, j, 1);
                addps(r0, output_ptr(i, j, 0));
                addps(r1, output_ptr(i, j, 1));
            }

        L(store_noadd);

        if (jcp.with_eltwise || jcp.with_binary) {
            assert(ur * load_loop_blk < 14);

            Label store_nopostops;
            test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
            jz(store_nopostops, T_NEAR);

            apply_postops(load_loop_blk, ur);

            L(store_nopostops);
        }

        for (int j = 0; j < ur; ++j)
            for (int i = 0; i < load_loop_blk; ++i) {
                movups(output_ptr(i, j, 0), reg_accum(i, j, 0));
                movups(output_ptr(i, j, 1), reg_accum(i, j, 1));
            }
    };

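    // One unrolled multiply-accumulate step: reg_bcast holds a single input
    // value broadcast to all four lanes; it is multiplied in place into the
    // preloaded weight registers and added to the accumulators. Weights for
    // the next u-step and the next broadcast value are reloaded as soon as
    // the current values are consumed, to hide load latency.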
    auto fma_block = [=](bool last_block) {
        for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
            for (int j = 0; j < ur; ++j) {
                for (int i = 0; i < load_loop_blk; ++i) {
                    mulps(reg_load(i, 0), reg_bcast);
                    mulps(reg_load(i, 1), reg_bcast);
                    addps(reg_accum(i, j, 0), reg_load(i, 0));
                    addps(reg_accum(i, j, 1), reg_load(i, 1));

                    if (j == ur - 1
                            && !(last_block
                                    && u == jcp.reduce_loop_unroll - 1)) {
                        movups(reg_load(i, 0), load_ptr(u + 1, i, 0));
                        movups(reg_load(i, 1), load_ptr(u + 1, i, 1));
                    }
                }
                if (j < ur - 1) {
                    movss(reg_bcast, bcast_ptr(u, j + 1));
                    shufps(reg_bcast, reg_bcast, 0);
                }
            } // for ur
            if (!last_block || u < jcp.reduce_loop_unroll - 1) {
                movss(reg_bcast, bcast_ptr(u + 1, 0));
                shufps(reg_bcast, reg_bcast, 0);
            }
        } // for reduce_loop_unroll
    };

    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);
    mov(aux_reg_bcast_data, aux1_reg_bcast_data);

    init();

    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop);
    {
        fma_block(false);
        add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    fma_block(true);

    store();
} // reduce_loop()

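// Bias gradient loop (backward_weights only): accumulates the diff_dst values
// of the current load block over the whole reduce dimension into diff_bias.
// The running diff_bias pointer is kept on the stack between calls.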
void jit_sse41_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) {
    if (!jcp.with_bias || jcp.prop_kind != backward_weights) return;

    Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
    Label diff_bias_load;

    auto diff_bias_ptr = [=](int i, int n) {
        return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)
                + 4 * n * sizeof(float)];
    };

    auto load_ptr = [=](int u, int i, int n) {
        return ptr[aux_reg_load_data
                + (i * jcp.os + u) * jcp.oc_block * sizeof(float)
                + 4 * n * sizeof(float)];
    };

    auto diff_bias_reg = [=](int i, int n) { return Xmm(2 * i + n + 1); };

    mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
    cmp(reg_diff_bias_data, 0);
    je(diff_bias_loop_out, T_NEAR);

    test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
    jz(diff_bias_load, T_NEAR);

    for (int i = 0; i < load_loop_blk; ++i) {
        auto r0 = diff_bias_reg(i, 0);
        auto r1 = diff_bias_reg(i, 1);
        xorps(r0, r0);
        xorps(r1, r1);
    }
    jmp(diff_bias_init_out, T_NEAR);

    L(diff_bias_load);
    for (int i = 0; i < load_loop_blk; ++i) {
        movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0));
        movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1));
    }

    L(diff_bias_init_out);
    mov(aux_reg_load_data, reg_load_data);
    mov(reduce_loop_iter, reg_reduce_loop_work);
    L(diff_bias_loop);
    {
        for (int u = 0; u < jcp.reduce_loop_unroll; ++u)
            for (int i = 0; i < load_loop_blk; ++i) {
                addps(diff_bias_reg(i, 0), load_ptr(u, i, 0));
                addps(diff_bias_reg(i, 1), load_ptr(u, i, 1));
            }
        assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jnz(diff_bias_loop, T_NEAR);
    }

    for (int i = 0; i < load_loop_blk; i++) {
        movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0));
        movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1));
    }

    add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
    mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);

    L(diff_bias_loop_out);
}

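// Kernel entry point: reserves stack space, loads the call arguments from the
// jit_1x1_conv_call_s structure and processes the load dimension (output
// channels for forward propagation) with specializations that keep 3, 2 or 1
// blocks of 8 channels in registers at once.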
void jit_sse41_1x1_conv_kernel_f32::generate() {
    preamble();

    sub(rsp, stack_space_needed);
    if (jcp.with_binary) {
        // backup abi_param1 for usage in post_ops processing
        mov(ptr[rsp + reg_abi_param1_backup], abi_param1);

        // zero initialize binary post_ops offset accumulator (store on stack)
        const auto zeroed_reg = r15;
        xor_(zeroed_reg, zeroed_reg);
        mov(ptr[rsp + reg_binary_post_op_acc_off], zeroed_reg);
    }

    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
    if (jcp.with_bias) {
        if (jcp.prop_kind == backward_weights) {
            mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
            mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
        } else
            mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
    }

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
    if (jcp.prop_kind == backward_weights)
        mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);

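    // Body of one load-loop iteration: run the bcast/reduce loops for
    // load_loop_blk blocks, then advance the data, bias and output pointers
    // and, for binary post-ops, the output channel offset kept on the stack.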
    auto generate_load_loop_body = [=](int load_loop_blk) {
        generate_bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        switch (jcp.prop_kind) {
            case forward_training:
            case forward_inference:
                add(reg_bias_data,
                        load_loop_blk * jcp.oc_block * sizeof(float));
                add(reg_output_data,
                        get_load_loop_output_fwd_offset(jcp, load_loop_blk));
                if (jcp.with_binary) {
                    mov(aux_reg_load_data,
                            EVEX_compress_addr(
                                    rsp, reg_binary_post_op_acc_off));
                    add(aux_reg_load_data, jcp.load_block * load_loop_blk);
                    mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off),
                            aux_reg_load_data);
                }
                break;
            case backward_data:
                add(reg_output_data,
                        load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
                break;
            case backward_weights:
                for (int i = 0; i < load_loop_blk; i++)
                    add(reg_output_data, reg_output_stride);
                break;
            default: assert(!"invalid prop_kind");
        }
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    };

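    // Dispatch on the remaining load dimension: chunks of 24, 16 and 8
    // channels map to 3, 2 and 1 register blocks respectively. A remainder of
    // 32 is special-cased so it splits as 16 + 16 rather than 24 + 8.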
    Label load_loop_blk_8;
    Label load_loop_blk_16;
    Label load_loop_blk_24;
    Label load_loop_blk_end;

    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    cmp(reg_load_loop_work, 32);
    je(load_loop_blk_16, T_NEAR);

    cmp(reg_load_loop_work, 16);
    jle(load_loop_blk_16, T_NEAR);

    L(load_loop_blk_24);
    {
        generate_diff_bias_loop(3);
        generate_load_loop_body(3);
        cmp(reg_load_loop_work, 32);
        je(load_loop_blk_16);
        cmp(reg_load_loop_work, 24);
        jge(load_loop_blk_24);
    }

    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    L(load_loop_blk_16);
    {
        generate_diff_bias_loop(2);
        generate_load_loop_body(2);
        cmp(reg_load_loop_work, 16);
        jge(load_loop_blk_16);
    }

    L(load_loop_blk_8);
    {
        cmp(reg_load_loop_work, 0);
        je(load_loop_blk_end, T_NEAR);
        generate_diff_bias_loop(1);
        generate_load_loop_body(1);
    }

    L(load_loop_blk_end);

    add(rsp, stack_space_needed);

    postamble();

    if (jcp.with_eltwise) postops_injector_->prepare_table();
}

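// Fills the jit_1x1_conv_conf_t descriptor: checks ISA and layout support,
// validates the post-ops chain and chooses the blocking parameters used by
// the generated code.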
status_t jit_sse41_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr, int nthreads) {
    if (!mayiuse(sse41)) return status::unimplemented;

    // TODO (Roma): this code is duplicated from the generic kernel; maybe the
    // configuration struct could handle this initialization instead
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    const int ndims = src_d.ndims();

    jcp.nthr = nthreads;

    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    // oc_without_padding is read by the kernel constructor to compute the
    // post-ops tail size, so set it here as well
    jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc = jcp.oc_without_padding;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
    jcp.ow = dst_d.dims()[ndims - 1];

    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    jcp.os = jcp.oh * jcp.ow;
    jcp.is = jcp.ih * jcp.iw;

    jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type);
    jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);

    const auto &post_ops = attr.post_ops_;

    const int dw_conv_ind = post_ops.find(primitive_kind::convolution);
    jcp.with_dw_conv = dw_conv_ind != -1;
    // Using dw_conv_ind as upper-bound below, as post-ops after it will be
    // handled in depthwise convolution.
    jcp.with_sum = post_ops.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
    const int eltwise_ind
            = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind);
    jcp.with_eltwise = eltwise_ind != -1;
    const int binary_ind
            = post_ops.find(primitive_kind::binary, 0, dw_conv_ind);
    jcp.with_binary = binary_ind != -1;

    if (dw_conv_ind >= 0) {
        // dw_conv and post_ops after it are handled externally, so skip them
        jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(),
                post_ops.entry_.cbegin() + dw_conv_ind);
    } else {
        jcp.post_ops = post_ops;
    }

    using namespace injector;
    static constexpr bool sum_at_pos_0_only = true;
    static constexpr bool sum_requires_scale_one = true;
    static constexpr bool sum_requires_zp_zero = true;
    const bool post_ops_ok_ = post_ops_ok({sse41, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
            sum_requires_zp_zero});
    if (!post_ops_ok_) return status::unimplemented;

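    // Only plain (nwc/nhwc) and 8-channel blocked (nCw8c/nChw8c) activations
    // with the matching 8x8 blocked weights layout are supported.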
    const auto dat_tag_nxc = utils::pick(ndims - 3, nwc, nhwc);
    const auto dat_tag_blocked = utils::pick(ndims - 3, nCw8c, nChw8c);
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_blocked);
    const bool is_data_layout_nxc
            = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
    const auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_blocked;

    const int is_bwd_d = jcp.prop_kind == backward_data;
    format_tag_t wei_tag = with_groups
            ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i,
                    gOIhw8i8o, gOIhw8o8i)
            : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o,
                    OIhw8o8i);
    jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);

    bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == dat_tag
            && jcp.wei_tag == wei_tag && jcp.dst_tag == dat_tag;
    if (!args_ok) return status::unimplemented;

    const int simd_w = 4;

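    // Channel blocks are 8 wide: each block is processed as two 4-float SSE
    // registers.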
    jcp.ic_block = jcp.oc_block = simd_w * 2;

    args_ok = true && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
            && jcp.t_pad == 0 && jcp.l_pad == 0 && jcp.stride_w == 1
            && jcp.stride_h == 1 // TODO: support some strides
            && jcp.ow == jcp.iw && jcp.oh == jcp.ih // enforce rpad=0
            && jcp.kh == 1 && jcp.kw == 1;
    if (!args_ok) return status::unimplemented;

    jcp.ur = 1;
    if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur);

    int load_blocking {0};
    int load_blocking_max {0};
    int bcast_blocking {0};
    int bcast_blocking_max {0};
    int reduce_blocking {0};

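    // Map the convolution dimensions onto the kernel's load/bcast/reduce
    // roles and pick the blocking constants for each propagation kind.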
    if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
        jcp.reduce_dim = jcp.ic;
        jcp.reduce_block = jcp.ic_block;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.is;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll
                * (is_data_layout_nxc ? 1 : jcp.is) * sizeof(float);
        jcp.reduce_loop_load_step
                = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur
                * (is_data_layout_nxc ? jcp.oc : jcp.oc_block) * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur
                * (is_data_layout_nxc ? jcp.ic : jcp.ic_block) * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        load_blocking = 120; // assumes the kernel is jcp.ur x 3
        load_blocking_max = 144;
        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 192;
        reduce_blocking = 128; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_data) {
        jcp.reduce_dim = jcp.oc;
        jcp.reduce_block = jcp.oc_block;

        jcp.load_dim = jcp.ic;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.os;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
                = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
        jcp.reduce_loop_load_step
                = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.ic_block;

        load_blocking = 96; // assumes the kernel is jcp.ur x 3
        load_blocking_max = 144;
        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 196;
        reduce_blocking = 64; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_weights) {
        jcp.reduce_dim = jcp.os;
        jcp.reduce_block = 1;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.ic;
        jcp.bcast_block = jcp.ic_block;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
                = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
        jcp.reduce_loop_load_step
                = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);

        jcp.bcast_loop_output_step
                = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
        jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
        jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);

        jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        /* --- */

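        // Reduce the number of blocks per chunk by factors of 2 and 3 while
        // it exceeds the caps used below (32 load blocks, 9 bcast blocks), so
        // each chunk evenly divides its dimension.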
        load_blocking = div_up(jcp.load_dim, jcp.load_block);
        while (true) {
            if (load_blocking <= 32)
                break;
            else if (load_blocking % 2 == 0)
                load_blocking /= 2;
            else if (load_blocking % 3 == 0)
                load_blocking /= 3;
            else
                break;
        }
        load_blocking *= jcp.load_block;
        load_blocking_max = load_blocking;
        assert(jcp.load_dim % load_blocking == 0);

        bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        while (true) {
            if (bcast_blocking <= 9)
                break;
            else if (bcast_blocking % 2 == 0)
                bcast_blocking /= 2;
            else if (bcast_blocking % 3 == 0)
                bcast_blocking /= 3;
            else
                break;
        }
        bcast_blocking *= jcp.bcast_block;
        bcast_blocking_max = bcast_blocking;
        assert(jcp.bcast_dim % bcast_blocking == 0);

        reduce_blocking = 128; // affects L1$ utilization
    } else
        return status::unimplemented;

    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);

    assert(jcp.bcast_block % jcp.ur == 0);
    jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur;

    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = load_blocking / jcp.load_block;
    jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
    jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;

    jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    return status::success;
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl