/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <limits>

#include "common/c_types_map.hpp"
#include "common/memory.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/injectors/injector_utils.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp"
#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"

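// GET_OFF() yields the byte offset of a field inside jit_1x1_conv_call_s, the
// argument structure the driver passes to the kernel; the generated code reads
// those fields via ptr[param1 + GET_OFF(field)].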
#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::prop_kind;
using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::utils;

using namespace Xbyak;

jit_avx2_1x1_conv_kernel_f32::jit_avx2_1x1_conv_kernel_f32(
        const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, avx2)
    , jcp(ajcp)
    , attr_(attr) {
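    // Eltwise and binary post-ops are applied through a post-ops injector.
    // Ymm(15) and r13-r15 are reserved as injector helpers, and the
    // output-channel tail size is passed so partial vectors can be handled.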
    if (jcp.with_eltwise || jcp.with_binary) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static constexpr size_t helper_vmm_idx = 15;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;

        rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, r13, r14,
                r15, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size,
                use_exact_tail_scalar_bcast};
        static_params_t static_params {this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx2>>(
                this, jcp.post_ops, static_params);
    }
}

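// Outer loop over the broadcast (spatial) dimension: full blocks of
// jcp.bcast_block are processed in substeps of jcp.ur, followed by a tail
// path for the remaining jcp.ur_tail elements.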
void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) {
    mov(aux1_reg_bcast_data, ptr[rsp + reg_bcast_data_off]);
    mov(aux_reg_output_data, reg_output_data);
    mov(bcast_loop_iter, reg_bcast_loop_work);

    Label bcast_loop, bcast_loop_tail, large_tail;

    cmp(bcast_loop_iter, jcp.bcast_block);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop);
    {
        assert(jcp.bcast_block % jcp.ur == 0);
        const int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            if (i == num_substeps - 1) L(large_tail);
            generate_reduce_loop(load_loop_blk, jcp.ur);
            if (i < num_substeps - 1) {
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_substep);
            } else {
                add(aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data,
                        jcp.bcast_loop_output_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_output_substep);
            }
            sub(bcast_loop_iter, jcp.ur);
        }
        cmp(bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        if (jcp.ur_tail >= jcp.ur) {
            cmp(bcast_loop_iter, jcp.ur);
            jge(large_tail, T_NEAR);
        }
        if (jcp.ur_tail % jcp.ur > 0) {
            cmp(bcast_loop_iter, 0);
            jle(bcast_loop_tail_out, T_NEAR);
            generate_reduce_loop(load_loop_blk, jcp.ur_tail % jcp.ur);
            L(bcast_loop_tail_out);
        }
    }
}

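// Accumulators are laid out as Ymm(j * load_loop_blk + i); the load registers
// used by generate_reduce_loop() follow at Ymm(ur * load_loop_blk + i).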
static int vreg_accum_idx(const int load_loop_blk, int i, int j) {
    return (j * load_loop_blk + i);
}

static Ymm vreg_accum(const int load_loop_blk, int i, int j) {
    return Ymm(vreg_accum_idx(load_loop_blk, i, j));
}

template <typename F>
void iterate(const int load_loop_blk, const int ur, const int load_dim_tail,
        const F &f) {
    for (int i = 0; i < load_loop_blk; ++i) {
        const bool mask_flag = (load_dim_tail > 0) && (i == load_loop_blk - 1);
        for (int j = 0; j < ur; ++j)
            f(mask_flag, i, j);
    }
}
template <typename F>
void iterate(const int load_loop_blk, const int ur, const F &f) {
    iterate(load_loop_blk, ur, 0, f);
}

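// Applies eltwise/binary post-ops to the accumulator registers, but only on
// the last reduction chunk (FLAG_REDUCE_LAST). For binary post-ops the
// per-register output offsets are collected so the injector can address the
// matching destination elements, with a separate path when the last load
// block is a partial (tail) block.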
void jit_avx2_1x1_conv_kernel_f32::apply_postops(
        const int load_loop_blk, const int ur, const int load_dim_tail) {
    if (jcp.with_eltwise || jcp.with_binary) {
        assert(ur * load_loop_blk < 14);

        Label store_nopost_ops;
        test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
        jz(store_nopost_ops, T_NEAR);

        injector_utils::vmm_index_set_t vmm_idxs;
        if (jcp.with_binary) {
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
                    rhs_arg_params_tail;

            iterate(load_loop_blk, ur, load_dim_tail,
                    [&](const bool mask_flag, const int i, const int j) {
                        const size_t aux_output_offset
                                = (i * get_output_i_offset(jcp)
                                          + j * get_output_j_offset(jcp))
                                * sizeof(float);
                        const auto vmm_idx
                                = vreg_accum_idx(load_loop_blk, i, j);
                        vmm_idxs.emplace(vmm_idx);

                        rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
                                vmm_idx, aux_reg_output_data);
                        rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
                                vmm_idx, aux_output_offset);
                        if (mask_flag)
                            rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
                    });
            rhs_arg_params = rhs_arg_params_tail;
            rhs_arg_params.vmm_tail_idx_.clear();

            const injector_utils::register_preserve_guard_t register_guard(
                    this, {abi_param1});
            const size_t reg_guard_stack_occupied
                    = register_guard.stack_space_occupied();
            mov(abi_param1,
                    ptr[rsp + reg_abi_param1_backup
                            + reg_guard_stack_occupied]);

            Label postops_done;
            if (load_dim_tail) {
                Label postops_no_tail;
                cmp(reg_load_loop_work,
                        load_loop_blk * jcp.load_loop_iter_step);
                jge(postops_no_tail, T_NEAR);
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params_tail);
                jmp(postops_done, T_NEAR);
                L(postops_no_tail);
            }
            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
            L(postops_done);
        } else {
            iterate(load_loop_blk, ur, load_dim_tail,
                    [&](const bool, const int i, const int j) {
                        vmm_idxs.emplace(vreg_accum_idx(load_loop_blk, i, j));
                    });
            postops_injector_->compute_vector_range(vmm_idxs);
        }
        L(store_nopost_ops);
    }
}

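// Inner reduction loop: init() loads the bias (forward) or zeroes the
// accumulators, fma_block() accumulates jcp.reduce_loop_unroll reduce
// elements per iteration, and store() adds the previous output when needed,
// applies post-ops and writes the results back. Load-dimension tails are
// handled with partial-width load_bytes()/store_bytes().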
void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
        int load_loop_blk, int ur) {
    const int load_dim_tail
            = ((jcp.with_binary
                       && one_of(jcp.prop_kind, forward_training,
                               forward_inference))
                              ? jcp.oc_without_padding
                              : jcp.load_dim)
            % jcp.load_block;
    const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block;

    auto vreg_load = [=](int i) { return Ymm(ur * load_loop_blk + i); };

    auto bias_ptr = [=](int i) {
        return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i];
    };

    auto bcast_ptr = [=](int u, int j) {
        assert(j < jcp.ur);
        assert(u <= jcp.reduce_loop_unroll);
        const size_t offset = get_bcast_offset(jcp, u, j);
        return make_safe_addr(aux_reg_bcast_data, offset, reg_long_offt);
    };

    auto get_load_offset_bwd_w = [=](int u, int i) {
        size_t u0 = u % jcp.reduce_loop_unroll;
        size_t u1 = u / jcp.reduce_loop_unroll;
        return u1 * jcp.reduce_loop_load_step
                + sizeof(float) * get_load_bwd_w_offset(jcp, i, u0);
    };

    auto load_ptr = [=](int u, int i) {
        size_t offt;
        size_t u0 = u % jcp.reduce_loop_unroll;
        size_t u1 = u / jcp.reduce_loop_unroll;
        switch (jcp.prop_kind) {
            case backward_data:
                offt = (i * jcp.oc_block + u0) * jcp.ic_block;
                break;
            case backward_weights:
                offt = get_load_bwd_w_offset(jcp, i, u0);
                break;
            default:
                offt = (i * rnd_up(jcp.ic, jcp.ic_block) + u0) * jcp.oc_block;
        }
        return ptr[aux_reg_load_data + u1 * jcp.reduce_loop_load_step
                + sizeof(float) * offt];
    };

    auto get_output_offset = [=](int i, int j) {
        switch (jcp.prop_kind) {
            case backward_weights: return sizeof(float) * jcp.oc_block * j;
            default:
                return (i * get_output_i_offset(jcp)
                               + j * get_output_j_offset(jcp))
                        * sizeof(float);
        }
    };

    auto output_ptr = [=](int i, int j) {
        switch (jcp.prop_kind) {
            case backward_weights:
                return ptr[aux_reg_output_data
                        + (i ? reg_output_stride * i
                             : 0) // TODO: Xbyak should allow 0 scale
                        + sizeof(float) * jcp.oc_block * j];
            default:
                const size_t off = get_output_offset(i, j);
                return make_safe_addr(aux_reg_output_data, off, reg_long_offt);
        }
    };

    auto init = [=]() {
        Label init_done, init_zero;

        if (jcp.with_bias
                && one_of(jcp.prop_kind, forward_training, forward_inference)) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jz(init_zero, T_NEAR);

            for (int i = 0; i < load_loop_blk; i++) {
                for (int j = 0; j < ur; ++j) {
                    if (load_dim_tail > 0 && i == load_loop_blk - 1) {
                        Label load_bias_tail, load_bias_done;
                        cmp(reg_load_loop_work,
                                load_loop_blk * jcp.load_loop_iter_step);
                        jl(load_bias_tail);
                        vmovups(vreg_accum(load_loop_blk, i, j), bias_ptr(i));
                        jmp(load_bias_done);

                        L(load_bias_tail);
                        load_bytes(vreg_accum(load_loop_blk, i, j),
                                reg_bias_data, i * jcp.oc_block * sizeof(float),
                                load_dim_tail * sizeof(float));
                        L(load_bias_done);
                    } else {
                        vmovups(vreg_accum(load_loop_blk, i, j), bias_ptr(i));
                    }
                }
            }
            jmp(init_done);
        }

        L(init_zero);
        for (int i = 0; i < load_loop_blk; ++i)
            for (int j = 0; j < ur; ++j) {
                auto r = vreg_accum(load_loop_blk, i, j);
                vxorps(r, r, r);
            }

        L(init_done);
        for (int i = 0; i < load_loop_blk; ++i) {
            if (jcp.prop_kind == backward_weights && load_dim_tail > 0
                    && i == load_loop_blk - 1) {
                Label load_init_tail, load_init_done;
                cmp(reg_load_loop_work,
                        load_loop_blk * jcp.load_loop_iter_step);
                jl(load_init_tail);
                vmovups(vreg_load(i), load_ptr(0, i));
                jmp(load_init_done);

                L(load_init_tail);
                vxorps(vreg_load(i), vreg_load(i), vreg_load(i));
                load_bytes(vreg_load(i), aux_reg_load_data,
                        get_load_offset_bwd_w(0, i),
                        load_dim_tail * sizeof(float));
                L(load_init_done);
            } else {
                vmovups(vreg_load(i), load_ptr(0, i));
            }
        }
        vbroadcastss(vreg_bcast, bcast_ptr(0, 0));
    };

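    // store(): optionally add the previous output contents (always when the
    // sum post-op is present, otherwise only for non-first reduce chunks),
    // apply post-ops, then write the accumulators back with a byte-granular
    // tail path for a partial last load block.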
    auto store = [=]() {
        Label store_noadd;

        if (!jcp.with_sum) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jnz(store_noadd, T_NEAR);
        }

        for (int j = 0; j < ur; ++j)
            for (int i = 0; i < load_loop_blk; ++i) {
                auto r = vreg_accum(load_loop_blk, i, j);
                if (jcp.with_sum && load_dim_tail > 0
                        && i == load_loop_blk - 1) {
                    Label sum_tail, sum_done;
                    cmp(reg_load_loop_work,
                            load_loop_blk * jcp.load_loop_iter_step);
                    jl(sum_tail);
                    vaddps(r, r, output_ptr(i, j));
                    jmp(sum_done);

                    L(sum_tail);
                    load_bytes(vtmp, aux_reg_output_data,
                            get_output_offset(i, j),
                            load_dim_tail * sizeof(float));
                    vaddps(r, r, vtmp);
                    L(sum_done);
                } else {
                    vaddps(r, r, output_ptr(i, j));
                }
            }

        L(store_noadd);

        apply_postops(load_loop_blk, ur, load_dim_tail);

        if (jcp.prop_kind == backward_weights && load_dim_tail > 0) {
            push(aux_reg_bcast_data);
        }

        const auto is_padding = jcp.oc_without_padding != jcp.oc;
        if (is_padding) uni_vxorps(vtmp, vtmp, vtmp);
        for (int j = 0; j < ur; ++j)
            for (int i = 0; i < load_loop_blk; ++i) {
                if (load_dim_tail > 0 && i == load_loop_blk - 1) {
                    Label store_tail, store_done;
                    cmp(reg_load_loop_work,
                            load_loop_blk * jcp.load_loop_iter_step);
                    jl(store_tail);
                    vmovups(output_ptr(i, j), vreg_accum(load_loop_blk, i, j));
                    jmp(store_done);

                    L(store_tail);
                    if (jcp.prop_kind == backward_weights) {
                        if (i) {
                            xor_(reg_tmp, reg_tmp); // rdx
                            mov(reg_tmp_output_stride,
                                    reg_output_stride); // rax
                            mov(reg_output_stride_scale, i);
                            imul(reg_output_stride_scale);
                        } else {
                            xor_(reg_tmp_output_stride, reg_tmp_output_stride);
                        }
                        lea(reg_tmp,
                                ptr[aux_reg_output_data
                                        + reg_tmp_output_stride]);
                        vmovups(output_ptr(i, j),
                                vreg_accum(load_loop_blk, i, j));
                    } else {
                        if (is_padding && jcp.with_binary) {
                            vmovups(ptr[aux_reg_output_data
                                            + get_output_offset(i, j)],
                                    vtmp);
                        }
                        store_bytes(vreg_accum(load_loop_blk, i, j),
                                aux_reg_output_data, get_output_offset(i, j),
                                load_dim_tail * sizeof(float));
                    }
                    L(store_done);
                } else {
                    vmovups(output_ptr(i, j), vreg_accum(load_loop_blk, i, j));
                }
            }

        if (jcp.prop_kind == backward_weights && load_dim_tail > 0) {
            pop(aux_reg_bcast_data);
        }
    };

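    // fma_block(): one unrolled chunk of the reduction. AVX2 uses a fused
    // vfmadd231ps; plain AVX falls back to a vmulps/vaddps pair through vtmp.
    // The next load registers and broadcast value are fetched ahead of their
    // use to overlap loads with the arithmetic.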
    auto fma_block = [=](bool last_block) {
        const bool is_tail = reduce_dim_tail && last_block;
        const int u_end = is_tail ? reduce_dim_tail : jcp.reduce_loop_unroll;
        for (int u = 0; u < u_end; ++u) {
            for (int j = 0; j < ur; ++j) {
                for (int i = 0; i < load_loop_blk; ++i) {
                    if (jcp.isa == avx2)
                        vfmadd231ps(vreg_accum(load_loop_blk, i, j),
                                vreg_load(i), vreg_bcast);
                    else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
                        vmulps(vtmp, vreg_bcast, vreg_load(i));
                        vaddps(vreg_accum(load_loop_blk, i, j),
                                vreg_accum(load_loop_blk, i, j), vtmp);
                    }
                    if (j == ur - 1 && !(last_block && u == u_end - 1)) {
                        if (jcp.prop_kind == backward_weights
                                && load_dim_tail > 0
                                && i == load_loop_blk - 1) {
                            Label fma_load_tail, fma_load_done;
                            cmp(reg_load_loop_work,
                                    load_loop_blk * jcp.load_loop_iter_step);
                            jl(fma_load_tail);
                            vmovups(vreg_load(i), load_ptr(u + 1, i));
                            jmp(fma_load_done);

                            L(fma_load_tail);
                            vxorps(vreg_load(i), vreg_load(i), vreg_load(i));
                            load_bytes(vreg_load(i), aux_reg_load_data,
                                    get_load_offset_bwd_w(u + 1, i),
                                    load_dim_tail * sizeof(float));
                            L(fma_load_done);
                        } else {
                            vmovups(vreg_load(i), load_ptr(u + 1, i));
                        }
                    }
                }
                if (j < ur - 1) vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1));
            }
            if (!last_block || u < u_end - 1)
                vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0));
        }
    };

    Label reduce_loop, reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);
    mov(aux_reg_bcast_data, aux1_reg_bcast_data);

    init();

    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop);
    {
        fma_block(false);
        safe_add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step, reg_long_offt);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    fma_block(true);

    store();
}

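// Accumulates the bias gradient for backward-weights by summing the output
// gradient over the reduce (spatial) dimension; skipped entirely when no
// diff-bias pointer was passed for this chunk.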
void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) {
    if (!jcp.with_bias || jcp.prop_kind != backward_weights) return;

    Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
    Label diff_bias_load;

    auto diff_bias_ptr = [=](int i) {
        return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)];
    };

    auto load_ptr = [=](int u, int i) {
        return ptr[aux_reg_load_data
                + (i * jcp.os + u) * jcp.oc_block * sizeof(float)];
    };

    auto diff_bias_reg = [=](int i) { return Ymm(i); };

    mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
    cmp(reg_diff_bias_data, 0);
    je(diff_bias_loop_out, T_NEAR);

    test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
    jz(diff_bias_load, T_NEAR);

    for (int i = 0; i < load_loop_blk; ++i) {
        auto r = diff_bias_reg(i);
        vxorps(r, r, r);
    }
    jmp(diff_bias_init_out, T_NEAR);

    L(diff_bias_load);
    for (int i = 0; i < load_loop_blk; ++i)
        vmovups(diff_bias_reg(i), diff_bias_ptr(i));

    L(diff_bias_init_out);
    mov(aux_reg_load_data, reg_load_data);
    mov(reduce_loop_iter, reg_reduce_loop_work);
    L(diff_bias_loop);
    {
        for (int u = 0; u < jcp.reduce_loop_unroll; ++u)
            for (int i = 0; i < load_loop_blk; ++i)
                vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i));
        assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jnz(diff_bias_loop, T_NEAR);
    }

    for (int i = 0; i < load_loop_blk; i++)
        vmovups(diff_bias_ptr(i), diff_bias_reg(i));
    add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
    mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);

    L(diff_bias_loop_out);
}

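// Kernel entry point: set up the stack frame, load the call arguments from
// jit_1x1_conv_call_s, then drive the load loop with a register blocking of
// up to three load blocks per pass.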
void jit_avx2_1x1_conv_kernel_f32::generate() {
    preamble();

    sub(rsp, stack_space_needed);

    if (jcp.with_binary) {
        const auto zeroed_reg = r15;
        xor_(zeroed_reg, zeroed_reg);
        mov(ptr[rsp + reg_binary_post_op_acc_off], zeroed_reg);
        mov(ptr[rsp + reg_abi_param1_backup], abi_param1);
    }

    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
    if (jcp.with_bias) {
        if (jcp.prop_kind == backward_weights) {
            mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
            mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
        } else
            mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
    }

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
    if (jcp.prop_kind == backward_weights)
        mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);

    auto generate_load_loop_body = [=](int load_loop_blk) {
        generate_bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        switch (jcp.prop_kind) {
            case forward_training:
            case forward_inference:
                add(reg_bias_data,
                        load_loop_blk * jcp.oc_block * sizeof(float));
                safe_add(reg_output_data,
                        get_load_loop_output_fwd_offset(jcp, load_loop_blk),
                        reg_long_offt);
                if (jcp.with_binary) {
                    mov(aux_reg_load_data,
                            ptr[rsp + reg_binary_post_op_acc_off]);
                    add(aux_reg_load_data, jcp.load_block * load_loop_blk);
                    mov(ptr[rsp + reg_binary_post_op_acc_off],
                            aux_reg_load_data);
                }
                break;
            case backward_data:
                safe_add(reg_output_data,
                        get_load_loop_output_bwd_d_offset(jcp, load_loop_blk),
                        reg_long_offt);
                break;
            case backward_weights:
                for (int i = 0; i < load_loop_blk; i++)
                    add(reg_output_data, reg_output_stride);
                break;
            default: assert(!"invalid prop_kind");
        }
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    };

    Label load_loop_blk_8;
    Label load_loop_blk_16;
    Label load_loop_blk_24;
    Label load_loop_blk_end;

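    // Dispatch on the remaining load dimension (in channels; one load block is
    // jcp.load_loop_iter_step = 8 channels): pick a register blocking of 3, 2,
    // or 1 load blocks. A remaining count of 32 is steered to the 2-block path
    // so it splits as 16 + 16 rather than 24 + 8.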
    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    cmp(reg_load_loop_work, 32);
    je(load_loop_blk_16, T_NEAR);

    cmp(reg_load_loop_work, 16);
    jle(load_loop_blk_16, T_NEAR);

    L(load_loop_blk_24);
    {
        generate_diff_bias_loop(3);
        generate_load_loop_body(3);
        cmp(reg_load_loop_work, 32);
        je(load_loop_blk_16);
        cmp(reg_load_loop_work, 24);
        jge(load_loop_blk_24);
    }

    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    L(load_loop_blk_16);
    {
        generate_diff_bias_loop(2);
        generate_load_loop_body(2);
        cmp(reg_load_loop_work, 16);
        jge(load_loop_blk_16);
    }

    L(load_loop_blk_8);
    {
        cmp(reg_load_loop_work, 0);
        jle(load_loop_blk_end, T_NEAR);
        generate_diff_bias_loop(1);
        generate_load_loop_body(1);
    }

    L(load_loop_blk_end);

    add(rsp, stack_space_needed);

    postamble();

    if (jcp.with_eltwise) postops_injector_->prepare_table();
}

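// init_conf(): checks that the problem really is a stride-1, zero-padded 1x1
// convolution in a supported layout (nxc or blocked nCx8c) and fills in the
// blocking and loop-step parameters consumed by the code generator above.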
status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr) {
    if (!mayiuse(avx)) return status::unimplemented;
    jcp.isa = mayiuse(avx2) ? avx2 : avx;

    // TODO (Roma): this code is duplicated from the generic kernel; maybe the
    // configuration struct could do some stuff below
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    const int ndims = src_d.ndims();

    jcp.nthr = dnnl_get_max_threads();

    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc = jcp.oc_without_padding;
    jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
    jcp.ic = jcp.ic_without_padding;

    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];

    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind,
                            format_kind::undef, cd.diff_bias_desc.format_kind)
            != format_kind::undef;

    jcp.os = static_cast<dim_t>(jcp.od) * jcp.oh * jcp.ow;
    jcp.is = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw;

    jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type);
    jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);

    const auto &post_ops = attr.post_ops_;
    const int dw_conv_ind = post_ops.find(primitive_kind::convolution);
    jcp.with_dw_conv = dw_conv_ind != -1;

    // Using dw_conv_ind as upper-bound below, as post-ops after it will be
    // handled in depthwise convolution.
    const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind);
    jcp.with_sum = sum_ind != -1;
    const int eltwise_ind
            = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind);
    jcp.with_eltwise = eltwise_ind != -1;
    const int binary_ind
            = post_ops.find(primitive_kind::binary, 0, dw_conv_ind);
    jcp.with_binary = binary_ind != -1;

    if (dw_conv_ind >= 0) {
        // dw_conv and post_ops after it are handled externally, so skip them
        jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(),
                post_ops.entry_.cbegin() + dw_conv_ind);
    } else {
        jcp.post_ops = post_ops;
    }

    const auto dat_tag_nxc = utils::pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_nCx8c = utils::pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
    const bool is_data_layout_nxc
            = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
    const auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;

    const int is_bwd_d = jcp.prop_kind == backward_data;
    format_tag_t wei_tag = with_groups
            ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i,
                    gOIhw8i8o, gOIdhw8o8i, gOIhw8i8o, gOIdhw8o8i)
            : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o,
                    OIhw8o8i, OIdhw8i8o, OIdhw8o8i);
    jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);

    const int simd_w = 8;

    bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1;
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    if (jcp.with_eltwise || jcp.with_binary)
        if (jcp.isa < avx2) return status::unimplemented;

    using namespace injector;
    static constexpr bool sum_at_pos_0_only = true;
    static constexpr bool sum_requires_scale_one = true;
    static constexpr bool sum_requires_zp_zero = true;
    const bool post_ops_ok_ = post_ops_ok({avx2, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
            sum_requires_zp_zero});
    if (!post_ops_ok_) return status::unimplemented;

    bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == dat_tag
            && jcp.wei_tag == wei_tag && jcp.dst_tag == dat_tag;
    if (!args_ok) return status::unimplemented;

    args_ok = true && jcp.id == jcp.od && jcp.ih == jcp.oh && jcp.iw == jcp.ow
            && IMPLICATION(!is_data_layout_nxc,
                    jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0)
            && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0
            && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1
            && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1;
    if (!args_ok) return status::unimplemented;

    // TODO: remove this restriction
    // optimized 1x1 bwd_w does not support Intel AVX
    if (jcp.prop_kind == backward_weights && jcp.isa != avx2)
        return status::unimplemented;

    jcp.ic_block = jcp.oc_block = simd_w;

    jcp.ur = jcp.isa == avx2 ? 4 : 3; // Intel AVX support
    if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur);

    int load_blocking {0};
    int load_blocking_max {0};
    int bcast_blocking {0};
    int bcast_blocking_max {0};
    int reduce_blocking {0};
    int reduce_blocking_max {0};

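    // The roles of the load/bcast/reduce dimensions depend on the propagation
    // kind: forward reduces over IC and loads OC, backward-data reduces over
    // OC and loads IC, and backward-weights reduces over the spatial size
    // while broadcasting IC.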
    if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
        jcp.reduce_dim = jcp.ic;
        jcp.reduce_block = jcp.ic_block;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.is;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll
                * (is_data_layout_nxc ? 1 : jcp.is) * sizeof(float);
        jcp.reduce_loop_load_step
                = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur
                * (is_data_layout_nxc ? jcp.oc : jcp.oc_block) * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur
                * (is_data_layout_nxc ? jcp.ic : jcp.ic_block) * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step
                = rnd_up(jcp.ic, jcp.ic_block) * jcp.oc_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        load_blocking = is_data_layout_nxc
                ? jcp.load_dim
                : 120; // assumes the kernel is jcp.ur x 3
        load_blocking_max = is_data_layout_nxc ? jcp.load_dim : 144;
        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 192;
        reduce_blocking = is_data_layout_nxc
                ? jcp.reduce_dim
                : 128; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_data) {
        jcp.reduce_dim = jcp.oc;
        jcp.reduce_block = jcp.oc_block;

        jcp.load_dim = jcp.ic;
        jcp.load_block = jcp.ic_block;

        jcp.bcast_dim = jcp.os;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll
                * (is_data_layout_nxc ? 1 : jcp.os) * sizeof(float);
        jcp.reduce_loop_load_step = jcp.reduce_loop_unroll
                * rnd_up(jcp.ic, jcp.ic_block) * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur
                * (is_data_layout_nxc ? jcp.ic : jcp.ic_block) * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur
                * (is_data_layout_nxc ? jcp.oc : jcp.oc_block) * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.ic_block;

        load_blocking = is_data_layout_nxc
                ? jcp.load_dim
                : 96; // assumes the kernel is jcp.ur x 3
        load_blocking_max = is_data_layout_nxc ? jcp.load_dim : 144;

        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 196;
        reduce_blocking = is_data_layout_nxc
                ? jcp.reduce_dim
                : 64; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_weights) {
        jcp.reduce_dim = jcp.os;
        jcp.reduce_block = 1;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.ic;
        jcp.bcast_block = jcp.ic_block;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll
                * (is_data_layout_nxc ? jcp.ic : jcp.ic_block) * sizeof(float);
        jcp.reduce_loop_load_step = jcp.reduce_loop_unroll
                * (is_data_layout_nxc ? jcp.oc : jcp.oc_block) * sizeof(float);

        jcp.bcast_loop_output_step
                = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
        jcp.bcast_loop_bcast_step = jcp.ic_block
                * (is_data_layout_nxc ? 1 : jcp.is) * sizeof(float);
        jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);

        jcp.load_loop_load_step = jcp.oc_block
                * (is_data_layout_nxc ? 1 : jcp.os) * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        /* --- */

        load_blocking = div_up(jcp.load_dim, jcp.load_block);
        const bool no_load_tail = jcp.load_dim % jcp.load_block == 0;
        const bool modify_load_blocking
                = IMPLICATION(is_data_layout_nxc, no_load_tail);
        while (modify_load_blocking) {
            if (load_blocking <= 32)
                break;
            else if (load_blocking % 2 == 0)
                load_blocking /= 2;
            else if (load_blocking % 3 == 0)
                load_blocking /= 3;
            else
                break;
        }
        load_blocking *= jcp.load_block;
        load_blocking_max = load_blocking;
        assert(IMPLICATION(
                !is_data_layout_nxc, jcp.load_dim % load_blocking == 0));

        bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        const int bcast_blocking_lim = is_data_layout_nxc ? 17 : 9;
        const bool no_bcast_tail = jcp.bcast_dim % jcp.bcast_block == 0;
        const bool small_size_for_bcast
                = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw <= 1024;

        // TODO Verify if the size limitation helps for blocked format as well
        const bool modify_bcast_blocking = IMPLICATION(
                is_data_layout_nxc, no_bcast_tail && small_size_for_bcast);

        while (modify_bcast_blocking) {
            if (bcast_blocking <= bcast_blocking_lim)
                break;
            else if (bcast_blocking % 2 == 0)
                bcast_blocking /= 2;
            else if (bcast_blocking % 3 == 0)
                bcast_blocking /= 3;
            else
                break;
        }
        bcast_blocking *= jcp.bcast_block;
        bcast_blocking_max = bcast_blocking;
        assert(IMPLICATION(
                !is_data_layout_nxc, jcp.bcast_dim % bcast_blocking == 0));

        reduce_blocking = is_data_layout_nxc
                ? rnd_up(nstl::min(jcp.ow, 128), jcp.reduce_block)
                : 128; // affects L1$ utilization
        reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
    } else
        return status::unimplemented;

    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);

    assert(jcp.bcast_block % jcp.ur == 0);
    jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.bcast_block;

    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = div_up(load_blocking, jcp.load_block);
    jcp.nb_load_blocking_max = div_up(load_blocking_max, jcp.load_block);
    jcp.nb_reduce_blocking = div_up(reduce_blocking, jcp.reduce_block);
    jcp.nb_reduce_blocking_max = div_up(reduce_blocking_max, jcp.reduce_block);

    jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    if (jcp.prop_kind == backward_weights) {
        const auto mb_with_nb_reduce
                = static_cast<dim_t>(jcp.mb) * jcp.nb_reduce;
        // prevent too large argument to cpu reducer
        if (mb_with_nb_reduce > std::numeric_limits<int>::max())
            return status::unimplemented;
    }

    return status::success;
}

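// Books scratchpad space for a padded bias buffer when the output-channel
// count is not a multiple of the channel block (padded blocked layouts, or
// backward-weights with an OC tail).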
void jit_avx2_1x1_conv_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad,
        const jit_1x1_conv_conf_t &jcp) {
    using namespace dnnl::impl::memory_tracking::names;

    if (jcp.with_bias && jcp.prop_kind != backward_data
            && (jcp.oc != jcp.oc_without_padding // blocked format
                    || (jcp.prop_kind == backward_weights // nxc format
                            && jcp.oc % jcp.oc_block != 0))) {
        const size_t nelems_padded_bias
                = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block);
        scratchpad.book<float>(key_conv_padded_bias, nelems_padded_bias);
    }
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl