/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_UNI_1X1_CONV_UTILS_HPP
#define CPU_X64_JIT_UNI_1X1_CONV_UTILS_HPP

#include "common/convolution_pd.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/primitive_desc_iterator.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

struct reduce_to_unit_stride_t {
    convolution_desc_t conv_d_;
    bool reduce_src_;
    size_t space_per_thread_;
};

/* 1x1-kernel does not support non-unit strides so far, so the idea is:
 *  - for fwd or bwd_weights: to copy src to a scratch memory (with strides
 *    equal to 1) and then call the kernel
 *  - for bwd_data: reduce the problem to the one with unit stride by
 *    performing computations in a scratch memory (with strides equal to 1)
 *    and then copy the result to diff_src */
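/* Illustrative sketch (not part of the implementation): for a forward 1x1
 * convolution with strides (SH, SW), the dense scratch buffer ws receives,
 * per channel block, only the src points the strided kernel would touch,
 * roughly
 *
 *     for (int oh = 0; oh < OH; ++oh)
 *         for (int ow = 0; ow < OW; ++ow)
 *             ws[oh * OW + ow] = src[oh * SH * IW + ow * SW];
 *
 * where OH/OW are the dst spatial dims and IW is the src width; the
 * unit-stride kernel then runs on ws. For bwd_data the copy goes the other
 * way and the spatial positions skipped by the stride are zero-filled. */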
template <typename conv_pd_t>
inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d,
        const memory_desc_t *&src_d, const memory_desc_t *dst_d,
        const memory_desc_t *weights_d) {
    const int ndims = src_d->ndims;

    const bool with_groups
            = memory_desc_wrapper(weights_d).ndims() == ndims + 1;

    bool rtus_applicable = utils::one_of(ndims, 3, 4)
            && IMPLICATION(with_groups, weights_d->dims[0] == 1);
    if (ndims == 3)
        rtus_applicable = rtus_applicable && conv_d->strides[0] != 1
                && conv_d->src_desc.data_type != data_type::s32;
    else
        rtus_applicable = rtus_applicable
                && (conv_d->strides[0] != 1 || conv_d->strides[1] != 1);
    for (int d = 2; d < ndims; ++d) {
        /* TODO: relax these conditions (by improving reducer) */
        rtus_applicable = rtus_applicable && conv_d->padding[0][d - 2] == 0
                && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d];
    }
    if (!rtus_applicable) return;

    const auto dat_tag = ndims == 3
            ? memory_desc_wrapper(src_d).matches_one_of_tag(
                    format_tag::nCw8c, format_tag::nCw16c, format_tag::nwc)
            : memory_desc_wrapper(src_d).matches_one_of_tag(
                    format_tag::nChw8c, format_tag::nChw16c, format_tag::nhwc);
    if (dat_tag == format_tag::undef) return;

    const bool is_nspc
            = utils::one_of(dat_tag, format_tag::nwc, format_tag::nhwc);
    if (is_nspc && !mayiuse(sse41)) return;

    // rtus is applicable, configure it.
    self->rtus_.reduce_src_ = true;
    conv_d = &(self->rtus_.conv_d_ = *conv_d);
    self->rtus_.conv_d_.strides[0] = 1;
    if (ndims == 4) self->rtus_.conv_d_.strides[1] = 1;
    utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2);
    if (ndims == 4) utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2);
    const int ic = src_d->dims[1];
    if (self->desc()->prop_kind == prop_kind::backward_data) {
        data_type_t data_type = self->rtus_.conv_d_.diff_src_desc.data_type;
        src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d);
        self->rtus_.conv_d_.diff_src_desc.dims[1] = ic;
        self->rtus_.conv_d_.diff_src_desc.data_type = data_type;
        memory_desc_wrapper::compute_blocking(
                self->rtus_.conv_d_.diff_src_desc, dat_tag);
    } else {
        data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type;
        src_d = &(self->rtus_.conv_d_.src_desc = *dst_d);
        self->rtus_.conv_d_.src_desc.dims[1] = ic;
        self->rtus_.conv_d_.src_desc.data_type = data_type;
        memory_desc_wrapper::compute_blocking(
                self->rtus_.conv_d_.src_desc, dat_tag);
    }
}

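/* Books the scratchpad space used by the strided-to-dense copy: each thread
 * gets its own space_per_thread_ elements of the source data type, so the
 * total reservation is max_threads * space_per_thread_. */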
template <typename conv_pd_t>
inline void rtus_prepare_space_info(conv_pd_t *self,
        memory_tracking::registrar_t &scratchpad, int max_threads) {
    if (!self->rtus_.reduce_src_) return;
    const auto &jcp = self->jcp_;
    const bool is_nspc
            = utils::one_of(jcp.src_tag, format_tag::nhwc, format_tag::nwc);

    const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind,
            jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking);
    size_t typesize
            = types::data_type_size(self->invariant_src_md()->data_type);

    self->rtus_.space_per_thread_
            = is_nspc ? jcp.is * jcp.ic : factor * jcp.is * jcp.ic_block;
    scratchpad.book(memory_tracking::names::key_conv_rtus_space,
            max_threads * self->rtus_.space_per_thread_, typesize);
}

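/* JIT kernel that performs the actual copy between the (strided) user tensor
 * and the dense scratch buffer ws:
 *  - src_to_ws_ == true  (fwd, bwd_weights): gather src into ws;
 *  - src_to_ws_ == false (bwd_data): scatter ws back into diff_src, writing
 *    zeros to the spatial positions the strided convolution skips.
 * A rough usage sketch (based on how the 1x1 convolution drivers typically
 * invoke it; details may differ):
 *
 *     rtus_driver_t<isa>::call_params_t p;
 *     p.ws = ws_per_thread; p.src = src_ptr;
 *     p.icb = icb; p.os = os; p.iw_start = iw_start;
 *     (*rtus_driver_)(&p);
 */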
template <cpu_isa_t isa>
struct rtus_driver_t : public jit_generator {

    struct call_params_t {
        const void *ws; /* reduced image (w/ strides = 1) */
        const void *src; /* source image (w/ non-unit strides) */
        size_t icb;
        size_t os;
        size_t iw_start;
    };

    DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t)

    Xbyak::Reg64 reg_ws = r12;
    Xbyak::Reg64 reg_src = r13;
    Xbyak::Reg64 reg_icb = rdx;
    Xbyak::Reg64 reg_os = r11;
    Xbyak::Reg64 reg_iw_start = r8;

    Xbyak::Reg64 reg_cur_os = rax;
    Xbyak::Reg64 reg_cur_iw = r9;
    Xbyak::Reg64 reg_cur_src = r10;
    Xbyak::Reg64 reg_cur_src_fin = reg_cur_iw; /* just reuse */

    Xbyak::Opmask tail_mask = k2;

    // nspc section
    Xbyak::Reg64 reg_cur_icb = rax;
    Xbyak::Reg64 reg_tail_mask = r14;
    Xbyak::Reg64 reg_icb_remainder = rcx;
    Xbyak::Reg64 reg_ws_copy = r15;

    int iw_, stride_w_;
    int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_;
    bool src_to_ws_;
    size_t typesize_;
    int ic_, ic_tail_;
    bool is_nspc_;

    Xbyak::Xmm reg_zero;
    Xbyak::Xmm reg_v;

    rtus_driver_t(int iw, int stride_w, int src_step_h, int src_step_icb,
            int ws_step_icb, bool src_to_ws, size_t typesize, int ic,
            bool is_nspc = false)
        : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
        , iw_(iw)
        , stride_w_(stride_w)
        , src_step_h_(src_step_h)
        , src_step_icb_(src_step_icb)
        , ws_step_icb_(ws_step_icb)
        , src_to_ws_(src_to_ws)
        , typesize_(typesize)
        , ic_(ic)
        , is_nspc_(is_nspc) {
        using namespace Xbyak;

        assert(ic_ > 0);

        /* FIXME: derive the Vmm type at compile time.
         * Changing the register type at runtime seems dangerous, and some
         * xbyak functions might fail to work on reg_v and reg_zero because of
         * this data type change, e.g. uni_vpxor doesn't work on reg_zero
         * now. */
        auto Vmm = [=](int idx, size_t typesize) {
            Xmm res;
            if (is_nspc_) {
                switch (isa) {
                    case sse41: res = Xmm(idx); break;
                    case avx2: res = Ymm(idx); break;
                    case avx512_core: res = Zmm(idx); break;
                    default: assert(!"Not supported isa"); res = Xmm(idx);
                }
                return res;
            }
            switch (isa) {
                case sse41:
                    switch (typesize) {
                        case 4: res = Xmm(idx); break;
                        default:
                            assert(!"Not supported typesize");
                            res = Xmm(idx);
                    }
                    break;
                case avx2:
                    switch (typesize) {
                        case 4: res = Ymm(idx); break;
                        case 2: res = Xmm(idx); break;
                        default:
                            assert(!"Not supported typesize");
                            res = Ymm(idx);
                    }
                    break;
                case avx512_core:
                    switch (typesize) {
                        case 4: res = Zmm(idx); break;
                        case 2: res = Ymm(idx); break;
                        case 1: res = Xmm(idx); break;
                        default:
                            assert(!"Not supported typesize");
                            res = Zmm(idx);
                    }
            }
            return res;
        };

        reg_zero = Vmm(0, typesize);
        reg_v = Vmm(1, typesize);

        vlen_ = reg_v.getBit() / 8;
        vlen_shift_ = 0;

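        // vlen_shift_ is the log2 of the unit used to turn counts into byte
        // offsets: the element size for nspc layouts, the full vector length
        // for blocked layouts.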
        int tvlen = is_nspc_ ? typesize_ : vlen_;
        while (tvlen > 1) {
            tvlen /= 2;
            vlen_shift_++;
        }

        const int simd_w = vlen_ / sizeof(float);
        ic_tail_ = ic_ % simd_w;
    }

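    // Blocked-layout copy over the output spatial range: one full vector
    // (one channel block) per output point; for bwd_data the skipped stride
    // positions in src are zero-filled.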
    void loop_is() {
        using namespace Xbyak;

        mov(reg_cur_src, reg_src);
        mov(reg_cur_iw, reg_iw_start);
        mov(reg_cur_os, reg_os);

        Label is_loop;
        L(is_loop);

        if (src_to_ws_) {
            vmovups(reg_v, ptr[reg_cur_src]);
            vmovups(ptr[reg_ws], reg_v);
        } else {
            vmovups(reg_v, ptr[reg_ws]);
            vmovups(ptr[reg_cur_src], reg_v);
            for (int w = 1; w < stride_w_; ++w)
                vmovups(ptr[reg_cur_src + w * vlen_], reg_zero);
        }

        add(reg_ws, vlen_);
        add(reg_cur_src, stride_w_ * vlen_);

        // for 1d or stride_h=1 convolutions the loop over h should be skipped
        if (!(src_step_icb_ == iw_ || src_step_h_ == iw_)) {
            Label skip_h_step;
            add(reg_cur_iw, stride_w_);
            cmp(reg_cur_iw, iw_);
            jl(skip_h_step, T_NEAR);

            if (src_to_ws_) {
                add(reg_cur_src, (src_step_h_ - iw_) * vlen_);
            } else {
                mov(reg_cur_src_fin, reg_cur_src);
                add(reg_cur_src_fin, (src_step_h_ - iw_) * vlen_);
                Label ih_loop;
                L(ih_loop);

                for (int w = 0; w < stride_w_; ++w)
                    vmovups(ptr[reg_cur_src + w * vlen_], reg_zero);

                add(reg_cur_src, stride_w_ * vlen_);
                cmp(reg_cur_src, reg_cur_src_fin);
                jl(ih_loop, T_NEAR);
            }
            xor_(reg_cur_iw, reg_cur_iw);
            L(skip_h_step);
        }

        sub(reg_cur_os, vlen_);
        jnz(is_loop, T_NEAR);

        /* restore dst */
        sub(reg_ws, reg_os);
    }

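    // Channels-last copy: for every output spatial point the full channel
    // range is copied vector by vector, with an opmask-based tail on
    // avx512_core and byte-wise load/store helpers on the other ISAs.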
    void loop_is_nspc() {
        using namespace Xbyak;

        assert(is_nspc_);

        mov(reg_cur_src, reg_src);
        mov(reg_cur_iw, reg_iw_start);

        if (isa == avx512_core) {
            push(rcx); // preserve rcx, used for shift
            mov(reg_icb_remainder, reg_icb);
            and_(reg_icb_remainder,
                    (vlen_ / typesize_) - 1); // # of elements in tail
            mov(reg_tail_mask, 1);
            shl(reg_tail_mask, reg_icb_remainder.cvt8());
            dec(reg_tail_mask);
            pop(rcx);

            switch (typesize_) {
                case 4: kmovw(tail_mask, reg_tail_mask.cvt32()); break;
                case 2: kmovd(tail_mask, reg_tail_mask.cvt32()); break;
                case 1: kmovq(tail_mask, reg_tail_mask); break;
                default: assert(!"Unsupported typesize");
            }
        }

        auto load_reg = [=](const Xmm &vreg, const Reg64 &reg,
                                const int64_t offset, const int load_size) {
            if (isa == avx512_core) {
                const Address &addr = ptr[reg + offset];
                switch (typesize_) {
                    case 4: vmovups(vreg, addr); break;
                    case 2: vmovdqu16(vreg, addr); break;
                    case 1: vmovdqu8(vreg, addr); break;
                    default: assert(!"Unsupported typesize");
                }
            } else {
                // FIXME: figure out a better way for compile-time definition
                // of xmm/ymm registers
                const bool is_ymm = load_size > 16;
                if (is_ymm)
                    load_bytes(Ymm(vreg.getIdx()), reg, offset, load_size);
                else
                    load_bytes(vreg, reg, offset, load_size);
            }
        };

        auto store_reg = [=](const Reg64 &reg, const Xmm &vreg,
                                 const int64_t offset, const int store_size) {
            if (isa == avx512_core) {
                const Address &addr = ptr[reg + offset];
                switch (typesize_) {
                    case 4: vmovups(addr, vreg); break;
                    case 2: vmovdqu16(addr, vreg); break;
                    case 1: vmovdqu8(addr, vreg); break;
                    default: assert(!"Unsupported typesize");
                }
            } else {
                // FIXME: figure out a better way for compile-time definition
                // of xmm/ymm registers
                const bool is_ymm = store_size > 16;
                if (is_ymm)
                    store_bytes(Ymm(vreg.getIdx()), reg, offset, store_size);
                else
                    store_bytes(vreg, reg, offset, store_size);
            }
        };

        mov(reg_ws_copy, reg_ws);
        shl(reg_icb, vlen_shift_);

        const size_t w_step_factor = ic_ * typesize_;
        const size_t max_load_store_bytes = isa == sse41
                ? typesize_ == 4 ? 16 : 8
                : typesize_ == 4 ? 32 : 16;
        const size_t load_store_size
                = isa == avx512_core ? vlen_ : max_load_store_bytes;
        size_t load_store_tail_size = (typesize_ == 1 ? max_load_store_bytes
                                                      : ic_tail_ * typesize_);

        Label is_loop, ic_loop, ic_loop_tail, ic_loop_finish;
        L(is_loop);
        {
            mov(reg_cur_src, reg_src);
            mov(reg_ws, reg_ws_copy);
            mov(reg_cur_icb, reg_icb);

            L(ic_loop);
            {
                cmp(reg_cur_icb, load_store_size);
                jl(ic_loop_tail, T_NEAR);

                if (src_to_ws_) {
                    load_reg(reg_v, reg_cur_src, 0, load_store_size);
                    store_reg(reg_ws, reg_v, 0, load_store_size);
                } else {
                    load_reg(reg_v, reg_ws, 0, load_store_size);
                    store_reg(reg_cur_src, reg_v, 0, load_store_size);
                    for (int w = 1; w < stride_w_; ++w)
                        store_reg(reg_cur_src, reg_zero, w * w_step_factor,
                                load_store_size);
                }
                add(reg_ws, load_store_size);
                add(reg_cur_src, load_store_size);

                sub(reg_cur_icb, load_store_size);
                jmp(ic_loop, T_NEAR);
            }

            L(ic_loop_tail);
            {
                cmp(reg_cur_icb, 0);
                je(ic_loop_finish, T_NEAR);

                if (src_to_ws_) {
                    load_reg(reg_v | tail_mask, reg_cur_src, 0,
                            load_store_tail_size);
                    store_reg(
                            reg_ws, reg_v | tail_mask, 0, load_store_tail_size);
                } else {
                    load_reg(
                            reg_v | tail_mask, reg_ws, 0, load_store_tail_size);
                    store_reg(reg_cur_src, reg_v | tail_mask, 0,
                            load_store_tail_size);
                    for (int w = 1; w < stride_w_; ++w)
                        store_reg(reg_cur_src, reg_zero | tail_mask,
                                w * w_step_factor, load_store_tail_size);
                }
            }
            L(ic_loop_finish);

            add(reg_ws_copy, w_step_factor);
            add(reg_src, stride_w_ * w_step_factor);

            // for 1d or stride_h=1 convolutions the loop over h should be skipped
            const bool skip_oh_step = src_step_h_ == iw_;
            if (!skip_oh_step) {
                mov(reg_cur_src, reg_src);
                Label skip_h_step;
                add(reg_cur_iw, stride_w_);
                cmp(reg_cur_iw, iw_);
                jl(skip_h_step, T_NEAR);

                if (src_to_ws_) {
                    add(reg_src, (src_step_h_ - iw_) * w_step_factor);
                } else {
                    mov(reg_cur_src_fin, reg_cur_src);
                    add(reg_cur_src_fin, (src_step_h_ - iw_) * w_step_factor);
                    Label ih_loop_nhwc, ic_ih_loop_nhwc, ic_tail_ih_loop_nhwc,
                            ic_finish_ih_loop_nhwc;
                    L(ih_loop_nhwc);
                    mov(reg_cur_src, reg_src);
                    mov(reg_cur_icb, reg_icb);
                    L(ic_ih_loop_nhwc);
                    cmp(reg_cur_icb, load_store_size);
                    jl(ic_tail_ih_loop_nhwc, T_NEAR);

                    for (int w = 0; w < stride_w_; ++w)
                        store_reg(reg_cur_src, reg_zero, w * w_step_factor,
                                load_store_size);

                    add(reg_cur_src, load_store_size);
                    sub(reg_cur_icb, load_store_size);
                    jnz(ic_ih_loop_nhwc, T_NEAR);

                    L(ic_tail_ih_loop_nhwc);
                    cmp(reg_cur_icb, 0);
                    jle(ic_finish_ih_loop_nhwc, T_NEAR);

                    for (int w = 0; w < stride_w_; ++w)
                        store_reg(reg_cur_src, reg_zero | tail_mask,
                                w * w_step_factor, load_store_tail_size);

                    L(ic_finish_ih_loop_nhwc);

                    add(reg_src, stride_w_ * w_step_factor);
                    cmp(reg_src, reg_cur_src_fin);
                    jl(ih_loop_nhwc, T_NEAR);
                }
                xor_(reg_cur_iw, reg_cur_iw);
                L(skip_h_step);
            }

            sub(reg_os, 1);
            jnz(is_loop, T_NEAR);
        }
    }

    void generate() override {
        using namespace Xbyak;
        assert(utils::one_of(isa, sse41, avx2, avx512_core));

        preamble();
#define READ_PARAM(what) \
    mov(reg_##what, ptr[abi_param1 + offsetof(call_params_t, what)])
        READ_PARAM(src);
        READ_PARAM(icb);
        READ_PARAM(os);
        READ_PARAM(iw_start);
        READ_PARAM(ws);
#undef READ_PARAM

        if (!src_to_ws_) {
            switch (reg_zero.getBit() / 8) {
                case 16 /*xmm*/: uni_vpxor(reg_zero, reg_zero, reg_zero); break;
                case 32 /*ymm*/: {
                    Xbyak::Ymm ymm_z(reg_zero.getIdx());
                    uni_vpxor(ymm_z, ymm_z, ymm_z);
                    break;
                }
                case 64 /*zmm*/: {
                    Xbyak::Zmm zmm_z(reg_zero.getIdx());
                    uni_vpxor(zmm_z, zmm_z, zmm_z);
                    break;
                }
                default: assert(!"rtus kernel failure");
            }
        }
        if (is_nspc_) {
            loop_is_nspc();
        } else {
            shl(reg_os, vlen_shift_);

            Label icb_loop;
            L(icb_loop);

            loop_is();

            add(reg_ws, ws_step_icb_ * vlen_);
            add(reg_src, src_step_icb_ * vlen_);

            sub(reg_icb, vlen_ / typesize_);
            jnz(icb_loop, T_NEAR);
        }

        postamble();
    }
};

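/* Creates the rtus copy kernel for a primitive whose descriptor enabled
 * reduce_src_: the geometry (iw, strides, per-channel-block steps, element
 * size) is taken from the primitive descriptor, and the generated code is
 * compiled here via create_kernel(). */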
template <cpu_isa_t isa, typename conv_t>
inline status_t init_rtus_driver(conv_t *self) {
    const auto &conf = *self->pd();
    if (!conf.rtus_.reduce_src_) return status::success;

    const auto &cd = *conf.desc();
    const int ndims = conf.ndims();
    const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0];
    const int stride_w = cd.strides[ndims - 3];

    const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data;
    const auto &src_d = is_bwd_data ? *conf.diff_src_md() : *conf.src_md();

    const int ih = ndims == 3 ? 1 : src_d.dims[2];
    const int iw = src_d.dims[ndims - 1];
    const int ic = src_d.dims[1];

    const auto src_tag = memory_desc_wrapper(src_d).matches_one_of_tag(
            format_tag::nhwc, format_tag::nwc);
    const bool is_nspc = src_tag != format_tag::undef;
    const int src_step_h = stride_h * iw;
    const int src_step_icb = !is_nspc ? ih * iw : 1;
    const int ws_step_icb = !is_nspc ? conf.jcp_.is : 1;
    const bool src_to_ws = !is_bwd_data;
    const size_t typesize
            = types::data_type_size(self->pd()->invariant_src_md()->data_type);

    CHECK(safe_ptr_assign(self->rtus_driver_,
            new rtus_driver_t<isa>(iw, stride_w, src_step_h, src_step_icb,
                    ws_step_icb, src_to_ws, typesize, ic, is_nspc)));

    return self->rtus_driver_->create_kernel();
}

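/* Picks the divider in [min_divider, max_divider] (stepping by `step`) that
 * minimizes the relative rounding-up loss of `value`; among dividers with
 * equal loss, find_max = true prefers the largest one and find_max = false
 * the smallest. For example, best_divider(13, 2, 8, true) == 7 while
 * best_divider(13, 2, 8, false) == 2: both round 13 up to 14, i.e. one
 * element of padding. */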
inline int best_divider(int value, int min_divider, int max_divider,
        bool find_max, int step = 1) {
    using namespace dnnl::impl::utils;
    max_divider = nstl::max(1, nstl::min(max_divider, value));
    min_divider = nstl::max(1, nstl::min(min_divider, max_divider));

    auto loss_ratio = [](int total, int chunk) {
        return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk);
    };

    float min_loss = FLT_MAX;
    int x_divider = max_divider;
    for (int divider = max_divider; divider >= min_divider; divider -= step) {
        const float loss = loss_ratio(value, divider);
        if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) {
            min_loss = loss;
            x_divider = divider;
        }
    }
    return x_divider;
}

typedef jit_1x1_conv_conf_t jcp_t;

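/* Layout helpers shared by the uni 1x1 kernels: the *_nxc predicates detect
 * channels-last (nwc/nhwc/ndhwc) tensors for the bcast, load, and output
 * views, in which case the offset helpers below step over the full channel
 * dimension rather than a single ic_block/oc_block. */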
inline bool is_bcast_layout_nxc(const jcp_t &jcp) {
    switch (jcp.prop_kind) {
        case prop_kind::forward_training:
        case prop_kind::forward_inference:
        case prop_kind::backward_weights:
            return utils::one_of(jcp.src_tag, format_tag::ndhwc,
                    format_tag::nhwc, format_tag::nwc);
        case prop_kind::backward_data:
            return utils::one_of(jcp.dst_tag, format_tag::ndhwc,
                    format_tag::nhwc, format_tag::nwc);
        default: assert(!"invalid prop_kind"); return false;
    }
}

inline bool is_load_layout_nxc(const jcp_t &jcp) {
    return jcp.prop_kind == prop_kind::backward_weights
            && utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                    format_tag::nwc);
}

inline bool is_out_layout_nxc(const jcp_t &jcp) {
    switch (jcp.prop_kind) {
        case prop_kind::forward_training:
        case prop_kind::forward_inference:
            return utils::one_of(jcp.dst_tag, format_tag::ndhwc,
                    format_tag::nhwc, format_tag::nwc);
        case prop_kind::backward_data:
            return utils::one_of(jcp.src_tag, format_tag::ndhwc,
                    format_tag::nhwc, format_tag::nwc);
        case prop_kind::backward_weights: return false;
        default: assert(!"invalid prop_kind"); return false;
    }
}

inline size_t get_bcast_u_offset(const jcp_t &jcp) {
    return is_bcast_layout_nxc(jcp) ? jcp.ic : jcp.ic_block;
}

inline size_t get_bcast_j_offset(const jcp_t &jcp) {
    return is_bcast_layout_nxc(jcp) ? jcp.reduce_dim : jcp.reduce_loop_unroll;
}

inline size_t get_bcast_offset(const jcp_t &jcp, int u, int j) {
    size_t offset;
    if (utils::one_of(jcp.prop_kind, prop_kind::forward_training,
                prop_kind::forward_inference, prop_kind::backward_data)) {
        assert(jcp.reduce_loop_unroll == jcp.reduce_block);
        if (is_bcast_layout_nxc(jcp) || u != jcp.reduce_loop_unroll) {
            offset = j * get_bcast_j_offset(jcp) + u;
        } else {
            offset = (jcp.bcast_dim + j) * get_bcast_j_offset(jcp);
        }
    } else {
        offset = u * get_bcast_u_offset(jcp) + j;
    }
    return sizeof(float) * offset;
}

inline size_t get_load_u_offset(const jcp_t &jcp) {
    return is_load_layout_nxc(jcp) ? jcp.oc : jcp.oc_block;
}

inline size_t get_load_i_offset(const jcp_t &jcp) {
    return is_load_layout_nxc(jcp) ? jcp.oc_block : jcp.os;
}

inline size_t get_load_bwd_w_offset(const jcp_t &jcp, int i, int u0) {
    if (is_load_layout_nxc(jcp)) {
        return i * get_load_i_offset(jcp) + u0 * get_load_u_offset(jcp);
    } else {
        return (i * get_load_i_offset(jcp) + u0) * get_load_u_offset(jcp);
    }
}

inline size_t get_output_i_offset(const jcp_t &jcp) {
    if (is_out_layout_nxc(jcp)) {
        return jcp.load_block;
    } else {
        return (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block;
    }
}

inline size_t get_output_j_offset(const jcp_t &jcp) {
    return is_out_layout_nxc(jcp) ? jcp.load_dim : jcp.load_block;
}

inline size_t get_load_loop_output_fwd_offset(
        const jcp_t &jcp, int load_loop_blk) {
    size_t offset = load_loop_blk * jcp.oc_block * sizeof(float);
    if (!is_out_layout_nxc(jcp)) {
        offset *= jcp.with_dw_conv ? jcp.ow : jcp.os;
    }
    return offset;
}

inline size_t get_load_loop_output_bwd_d_offset(
        const jcp_t &jcp, int load_loop_blk) {
    size_t offset = load_loop_blk * jcp.ic_block * sizeof(float);
    if (!is_out_layout_nxc(jcp)) { offset *= jcp.os; }
    return offset;
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif