1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <functional>
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/math_utils.hpp"
23#include "common/memory_tracking.hpp"
24#include "common/nstl.hpp"
25#include "common/type_helpers.hpp"
26#include "common/utils.hpp"
27
28#include "cpu/cpu_batch_normalization_utils.hpp"
29#include "cpu/platform.hpp"
30#include "cpu/x64/cpu_barrier.hpp"
31#include "cpu/x64/jit_generator.hpp"
32
33#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
34#include "cpu/x64/jit_uni_batch_normalization.hpp"
35
36namespace dnnl {
37namespace impl {
38namespace cpu {
39namespace x64 {
40
41using namespace memory_tracking::names;
42
43using namespace Xbyak;
44namespace barrier = simple_barrier;
45
46using acc_data_t = float;
47
48namespace {
49dim_t get_c_padded(const batch_normalization_pd_t *pd) {
50 return pd->src_md()->padded_dims[1];
51}
52
53bool is_nspc(const memory_desc_wrapper &d) {
54 using namespace format_tag;
55 const bool is_nspc = d.matches_one_of_tag(nc, nwc, nhwc, ndhwc);
56 return is_nspc;
57}
58} // namespace
59
60struct jit_bnorm_conf_t {
61 // TODO: put all needed info here to avoid duplicate work and potentially
62 // diverging definitions of derived parameters
63 const batch_normalization_pd_t *pd_;
64
65 int simd_w_ {0};
66 size_t dt_size_ {0};
67 bool is_nspc_ {false};
68
69 // thread partition info
70 bool do_blocking_ {false};
71 bool is_spatial_thr_ {false};
72 dim_t C_blks_per_iter_ {0};
73 int C_nthr_ {0};
74 int N_nthr_ {0};
75 int S_nthr_ {0};
76 int64_t iters_ {0};
77 // C_blks and thread partition can change for last iteration
78 dim_t C_blks_last_iter_ {0};
79 int C_nthr_last_iter_ {0};
80 int N_nthr_last_iter_ {0};
81 int S_nthr_last_iter_ {0};
82
83 jit_bnorm_conf_t(const batch_normalization_pd_t *pd, int nthr, int simd_w)
84 : pd_(pd), simd_w_(simd_w) {
85
86 const dim_t N = pd_->MB();
87 const dim_t C_PADDED = get_c_padded(pd_);
88 const dim_t D = pd_->D();
89 const dim_t H = pd_->H();
90 const dim_t W = pd_->W();
91 const dim_t SP = D * H * W;
92
93 const memory_desc_wrapper src_d(pd_->src_md());
94 is_nspc_ = is_nspc(src_d);
95
96 dt_size_ = types::data_type_size(pd_->src_md()->data_type);
97 size_t data_size = dt_size_ * N * C_PADDED * SP;
98 const size_t l3_size = platform::get_per_core_cache_size(3) * nthr;
99 // TODO: cache balancing for nspc
100 const size_t l3_filling_factor = 4;
101 do_blocking_ = !is_nspc_ && data_size >= l3_size / l3_filling_factor;
102
103 // find thread partition over N, C_blks and SP
104
105 const dim_t C_blks = C_PADDED / simd_w_;
106
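        // When blocking is enabled, the C_blks channel blocks are processed
        // in iters_ chunks of C_blks_per_iter_ (chosen by cache_balance())
        // so that each chunk's working set stays cache-resident; the last,
        // possibly smaller, chunk gets its own thread partition below.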
107 if (do_blocking_) {
108 const int num_tensors = pd_->is_fwd() ? 1 : 2;
109 const size_t working_set_size
110 = dt_size_ * (N * SP * simd_w_) * num_tensors;
111 bnorm_utils::cache_balance(working_set_size, C_blks, N, nthr,
112 C_blks_per_iter_, iters_);
113 C_blks_last_iter_ = C_blks - (iters_ - 1) * C_blks_per_iter_;
114 } else {
115 C_blks_per_iter_ = C_blks;
116 iters_ = 1;
117 }
118
119 is_spatial_thr_
120 = this->thread_partition(/* is_spatial_thr_ = */ true, nthr,
121 /* dimensions */
122 N, C_blks_per_iter_, SP,
123 /* outputs */
124 C_nthr_, N_nthr_, S_nthr_);
125
126 if (iters_ > 1)
127 this->thread_partition(is_spatial_thr_, nthr,
128 /* dimensions */
129 N, C_blks_last_iter_, SP,
130 /* outputs */
131 C_nthr_last_iter_, N_nthr_last_iter_, S_nthr_last_iter_);
132 }
133
    // given nthr and the shape of the problem, choose the thread partition
    // to use (i.e., set N_nthr, C_nthr, and S_nthr)
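    // e.g. (illustrative): blocked layout with nthr = 8 and C_blks = 16
    // takes the first branch (nthr <= C_blks), giving C_nthr = 8,
    // N_nthr = 1, S_nthr = 1 and thus no spatial threading.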
136 bool thread_partition(bool spatial_thr_allowed, int nthr, dim_t N,
137 dim_t C_blks, dim_t SP, int &C_nthr, int &N_nthr, int &S_nthr) {
138 if (((nthr <= C_blks) && IMPLICATION(is_nspc_, N == 1))
139 || !dnnl_thr_syncable()) {
140 C_nthr = nthr;
141 N_nthr = 1;
142 S_nthr = 1;
143 } else {
144 if (is_nspc_) {
145 if (C_blks <= 8)
146 C_nthr = 1;
147 else if (nthr >= 8 && C_blks <= 32)
148 C_nthr = 8;
149 else {
150 C_nthr = (int)math::gcd((dim_t)nthr, C_blks);
151 // Unroll by channels in JIT kernel
152 if ((C_nthr == C_blks) || (C_nthr == nthr)) C_nthr = 1;
153 }
154 N_nthr = (int)nstl::min<dim_t>(N, nthr / C_nthr);
155 // heuristic for training on avx512_core_amx
156 // TODO: test heuristic when global stats flag is set
157 if (!pd_->use_global_stats() && 0 < dt_size_ && 0 < simd_w_
158 && 1 < C_nthr && nthr <= N
159 && mayiuse(avx512_core_amx)) {
160 const size_t data_size
161 = dt_size_ * N * SP * C_blks * simd_w_;
162 const size_t C_split_data_size
163 = utils::div_up(data_size, N_nthr);
164 const size_t N_split_data_size
165 = utils::div_up(data_size, nthr);
166 const size_t l2_size_per_core
167 = platform::get_per_core_cache_size(2);
168 const size_t l3_size_per_core
169 = platform::get_per_core_cache_size(3);
170 const size_t cache_size_per_core
171 = l2_size_per_core + l3_size_per_core;
172 // if current split is too big for cache, better to split by N
173 const bool condition1
174 = cache_size_per_core < C_split_data_size;
175 // if split by N is also too big for cache, bwd is better off as it was
176 const bool condition2 = pd_->is_fwd()
177 || cache_size_per_core >= N_split_data_size;
178 if (condition1 && condition2) {
179 C_nthr = 1;
180 N_nthr = nthr;
181 }
182 }
183 S_nthr = (int)nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr));
184 } else {
185 if (do_blocking_) {
186 N_nthr = (int)nstl::min<dim_t>(N, nthr);
187 C_nthr = (int)nstl::min<dim_t>(C_blks, nthr / N_nthr);
188 S_nthr = (int)nstl::min<dim_t>(
189 SP, nthr / (C_nthr * N_nthr));
190 } else {
191 C_nthr = (int)math::gcd((dim_t)nthr, C_blks);
192 N_nthr = (int)nstl::min<dim_t>(N, nthr / C_nthr);
193 S_nthr = (int)nstl::min<dim_t>(
194 SP, nthr / (C_nthr * N_nthr));
195 }
196 }
197
198 if (!spatial_thr_allowed) S_nthr = 1;
199
200 if (S_nthr < 1) S_nthr = 1;
201 }
202
        // spatial_thr_allowed is meant to help maintain
        // consistent decisions about spatial threading
        // between multiple invocations of this routine.
        // It is the caller's responsibility to check the
        // return value and pass it as a flag to the
        // next call if needed.
209 if (S_nthr == 1) spatial_thr_allowed = false;
210
211 return spatial_thr_allowed;
212 }
213};
214
215template <cpu_isa_t isa>
216struct jit_bnorm_t : public jit_generator {
217 struct call_params_t {
218 // keep all sizes at 8 bytes -- jit code expects this
219 size_t N_ithr, N_nthr;
220 size_t coff_max, soff_max;
221 size_t mb_stride_Bc, spat_size, spat_size_loc;
222 size_t S_s, S_tail;
223 size_t is_cblk_tail;
224 acc_data_t chan_size, eps, one;
225 const acc_data_t *scale;
226 const acc_data_t *shift;
227 const acc_data_t *mean, *var;
228 const acc_data_t *diff_scale;
229 const acc_data_t *diff_shift;
230 const void *src, *dst;
231 const void *diff_src, *diff_dst;
232 const acc_data_t *rbuf1, *rbuf2;
233 const uint8_t *ws;
234 barrier::ctx_64_t *barrier;
235 };
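    // A call_params_t describes one thread's portion of the work; the driver
    // passes its address in abi_param1 (reg_param), and load_common_params()
    // below spills the fields onto the stack frame.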
236
237 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t)
238
239 /* cpu specific part */
240 using Vmm = typename utils::conditional3<isa == sse41, Xmm, isa == avx2,
241 Ymm, Zmm>::type;
242 const AddressFrame &vmmword
243 = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
244
245 const int vlen = isa == sse41 ? 32 : cpu_isa_traits<isa>::vlen;
    int vlen_spat_data_
            = 0; // set by ctor depending on data type (xF16 or FP32)
248
249 const batch_normalization_pd_t *pd_ = nullptr;
250 const jit_bnorm_conf_t *jbp_ = nullptr;
251 bool is_bf16_ = false;
252 bool is_f16_ = false;
253 bool is_avx2_ne_xf16_ = false;
254
255 Reg64 reg_param = abi_param1;
256
257 Reg64 reg_scale = rbx;
258 Reg64 reg_rbuf1 = abi_not_param1;
259 Reg64 reg_rbuf2 = rdx;
260 Reg64 reg_coff_max_fwd_copy = reg_rbuf2;
261
262 Reg64 reg_mean = rbp;
263 Reg64 reg_var = reg_param;
264 Reg64 reg_diff_scale = rax;
265 Reg64 reg_coff_max_bwd_copy = reg_diff_scale;
266 Reg64 reg_shift = reg_rbuf1;
267
268 Reg64 reg_coff = r8;
269 Reg64 reg_coff_max = r9;
270 Reg64 reg_soff = r10;
271 Reg64 reg_soff_max = r11;
272 Reg64 reg_diff_shift = reg_soff_max;
273 Reg64 reg_ctr = r12;
274 Reg64 reg_roff = r13;
275
276 Reg64 reg_mb_stride_Bc = r14;
277 Reg64 reg_soff_nspc = reg_mb_stride_Bc;
278
279 Reg64 reg_src = r15;
280 Reg64 reg_diff_src = reg_rbuf1;
281 Reg64 reg_dst = rsi;
282 Reg64 reg_diff_dst = reg_dst;
283
284 Reg64 reg_tmp_off = reg_roff;
285
286 // Reuse loop counters
287 Reg64 reg_bar = reg_coff;
288 Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff
289 Reg64 reg_tmp = reg_ctr;
290
291 // Relu section
292 bool with_relu = false, with_relu_inf_only = false;
293 Reg64 reg_ws = reg_roff;
294 Reg64 reg_tmp_alpha = reg_diff_scale; // required in sse41
295 Label l_relu_mask_avx2;
296 Opmask kstore_mask = Opmask(1);
297
298 // channel tail processing
299 Opmask ktail_mask = Opmask(2);
300
301 // FP32->BF16 emulation
302 bf16_emulation_t *bf16_emu_ {nullptr};
303 Reg64 reg_bf16_tmp = reg_tmp;
304 Zmm bf16_emu_reserved_1 = Zmm(17);
305 Zmm bf16_emu_reserved_2 = Zmm(18);
306 Zmm bf16_emu_reserved_3 = Zmm(19);
307 Zmm bf16_emu_reserved_4 = Zmm(20);
308
309 size_t unroll_blocks;
310 size_t unroll_regs;
311 Vmm vdiff_beta = Vmm(isa == avx512_core ? 21 : 6);
312 Vmm vdiff_gamma = Vmm(isa == avx512_core ? 22 : 7);
313 Vmm vsqrtvar = Vmm(isa == avx512_core ? 23 : 8);
314 Vmm vone = Vmm(isa == avx512_core ? 24 : 9);
315 Vmm vmean = Vmm(isa == avx512_core ? 25 : 10);
316 Vmm vgamma = Vmm(isa == avx512_core ? 26 : 11);
317 Vmm vbeta = Vmm(isa == avx512_core ? 27 : 12);
318 Vmm veps = Vmm(isa == avx512_core ? 28 : 13);
319 Vmm vchan_size = Vmm(isa == avx512_core ? 29 : 14);
320 Vmm vtail_mask = Vmm(isa == avx512_core ? 30 : 15);
321 Vmm vtmp = Vmm(isa == avx512_core ? 31 : 5);
    Vmm vsrc_aux = vdiff_gamma; // used for xf16 with nspc on avx2
    Vmm vdst_aux = vdiff_gamma; // used for ReLU on avx2 and sse41
324 Vmm vmask = Vmm(0);
325 Vmm vzero; // is_fwd() ? vdiff_beta : vbeta
326
327 size_t spat_size;
328 size_t chan_data_offt;
329 size_t spat_step;
330 size_t mb_offt;
331 size_t ws_mb_offt;
332
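    // Spill slots: each entry is an 8-byte offset from rsp; most are written
    // once in load_common_params() and re-read later when the corresponding
    // parameter register has to be reused.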
333 enum {
334 stack_off_N_nthr = 0,
335 stack_off_N_ithr = 8,
336 stack_off_src = 16,
337 stack_off_dst = 24,
338 stack_off_diff_src = 32,
339 stack_off_diff_dst = 40,
340 stack_off_diff_scale = 48,
341 stack_off_ws = 56,
342 stack_off_barrier = 64,
343 stack_off_spat_size_loc = 72,
344 stack_off_s_s = 80,
345 stack_off_s_tail = 88,
346 stack_off_is_cblk_tail = 96,
347 stack_off_ws_off_copy = 104,
348 stack_off_shift = 112,
349 stack_off_diff_shift = 120,
350 stack_off_soff_max = 128,
351 stack_off_relu_alpha = 136,
352 stack_size_required = 144,
353 };
354
355 bool is_xf16() { return is_bf16_ || is_f16_; }
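    // bit_shift(): the ReLU workspace keeps one bit per element, so a data
    // byte offset is converted to a ws byte offset by shifting right by 5
    // for f32 (32 data bytes per ws byte) or by 4 for xf16 (16 data bytes
    // per ws byte).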
356 int bit_shift() { return 5 - is_xf16(); }
357
358 bool use_bf16_emulation() {
359 return is_bf16_ && isa == avx512_core && !mayiuse(avx512_core_bf16);
360 }
361
362 bool stream_store_supported() {
363 // keep original behavior for f32
364 if (!is_xf16()) return true;
365 // TODO: check performance of heuristic for other cases, such as:
366 // blocked layout, pre-avx512_core_amx machines, and f32 datatype.
367 const bool is_applicable = jbp_->is_nspc_ && mayiuse(avx512_core_amx);
368 if (!is_applicable) return false;
369 const size_t l2_size_per_core = platform::get_per_core_cache_size(2);
370 const size_t l3_size_per_core = platform::get_per_core_cache_size(3);
371 const size_t cache_size_per_core = l2_size_per_core + l3_size_per_core;
372 const size_t buffer_count = pd_->is_fwd() ? 2 : 3;
373 const size_t data_size = buffer_count * jbp_->dt_size_ * pd_->MB()
374 * pd_->C() * pd_->D() * pd_->H() * pd_->W();
375 // do not divide by C_nthr for nspc layout
376 const size_t data_size_per_core
377 = data_size / (jbp_->N_nthr_ * jbp_->S_nthr_);
378 return cache_size_per_core < data_size_per_core;
379 }
380
381 bool is_c_padded() const {
382 const memory_desc_wrapper data_d(pd_->src_md());
383 return pd_->C() != data_d.padded_dims()[1];
384 }
385
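    // Byte strides used by the generated code: spat_step is the distance
    // between consecutive spatial points (C * dt_size for nspc, one vector
    // row for the blocked layout), mb_offt is the per-image stride used on
    // the nspc path, and ws_mb_offt is the matching stride in the ReLU
    // workspace (one bit per element).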
386 void compute_static_strides() {
387 spat_size = pd_->D() * pd_->W() * pd_->H();
388 chan_data_offt = pd_->C() * sizeof(acc_data_t);
389 spat_step = jbp_->is_nspc_ ? chan_data_offt / (1 + is_xf16())
390 : vlen_spat_data_;
391 mb_offt = spat_step * spat_size;
392 ws_mb_offt = (spat_step / (is_xf16() ? 16 : 32)) * spat_size;
393 }
394
395 void load_common_params() {
396#define PARAM_OFF(x) offsetof(call_params_t, x)
397 mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]);
398 if (!pd_->is_fwd()) mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]);
399 mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]);
400 mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]);
401 mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]);
402 shl(reg_coff_max, 2);
403
404 mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
405 mov(reg_scale, ptr[reg_param + PARAM_OFF(scale)]);
406
407 uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]);
408 uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]);
409 uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);
410
411 mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]);
412 mov(ptr[rsp + stack_off_N_nthr], reg_tmp);
413 mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]);
414 mov(ptr[rsp + stack_off_N_ithr], reg_tmp);
415 mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]);
416 mov(ptr[rsp + stack_off_src], reg_tmp);
417 mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]);
418 mov(ptr[rsp + stack_off_dst], reg_tmp);
419 mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]);
420 mov(ptr[rsp + stack_off_diff_src], reg_tmp);
421 mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]);
422 mov(ptr[rsp + stack_off_diff_dst], reg_tmp);
423 mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]);
424 mov(ptr[rsp + stack_off_ws], reg_tmp);
425 mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]);
426 mov(ptr[rsp + stack_off_barrier], reg_tmp);
427 if (jbp_->is_spatial_thr_) {
428 mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]);
429 mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp);
430 mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]);
431 mov(ptr[rsp + stack_off_s_s], reg_tmp);
432 mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]);
433 mov(ptr[rsp + stack_off_s_tail], reg_tmp);
434 }
435 if (is_c_padded()) {
436 mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]);
437 mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp);
438 }
439
440 if (pd_->is_fwd()) {
441 mov(reg_tmp, ptr[reg_param + PARAM_OFF(shift)]);
442 mov(ptr[rsp + stack_off_shift], reg_tmp);
443 mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
444 mov(reg_var, reg_tmp);
445 } else {
446 mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale)]);
447 mov(ptr[rsp + stack_off_diff_scale], reg_tmp);
448 mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_shift)]);
449 mov(ptr[rsp + stack_off_diff_shift], reg_tmp);
450 mov(reg_tmp, ptr[reg_param + PARAM_OFF(soff_max)]);
451 mov(ptr[rsp + stack_off_soff_max], reg_tmp);
452 mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
453 mov(reg_var, reg_tmp);
454 }
455 if (with_relu_inf_only && pd_->alpha() != 0.f) {
456 mov(reg_tmp, float2int(pd_->alpha()));
457 mov(ptr[rsp + stack_off_relu_alpha], reg_tmp);
458 }
459#undef PARAM_OFF
460 }
461
462 void prepare_tail_mask_avx512_common() {
463 if (!is_c_padded()) return;
464
465 const int tail = pd_->C() % (int)(vlen / sizeof(float));
466 const int mask = (1 << tail) - 1;
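        // e.g. (illustrative): C = 20 with 16 floats per vector gives
        // tail = 4 and mask = 0b1111, so only the first four lanes of the
        // last channel block are touched.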
467
468 Reg32 regw_tmp = reg_tmp.cvt32();
469 mov(regw_tmp, mask);
470 kmovw(ktail_mask, regw_tmp);
471 }
472
473 void prepare_tail_mask_avx2_common() {
474 if (!is_c_padded()) return;
475
476 const int tail = pd_->C() % (int)(vlen / sizeof(float));
477 static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
478 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0,
479 0, 0, 0, 0, 0, 0, 0};
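        // Loading 8 dwords starting at mask[8 - tail] yields `tail` leading
        // 0xffffffff entries followed by zeros -- the lane mask expected by
        // vmaskmovps.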
480
481 mov(reg_tmp, reinterpret_cast<size_t>(&mask[8 - tail]));
482 vmovups(vtail_mask, ptr[reg_tmp]);
483 }
484
485 void prepare_relu() {
486 with_relu = pd_->is_fwd() ? pd_->with_relu_post_op(pd_->is_training())
487 || pd_->fuse_norm_relu()
488 : pd_->fuse_norm_relu();
489 with_relu_inf_only = with_relu && pd_->is_fwd()
490 && !(pd_->fuse_norm_relu() && pd_->is_training());
491
492 vzero = pd_->is_fwd() ? vdiff_beta : vbeta;
493 if (with_relu) {
494 uni_vpxor(vzero, vzero, vzero);
495 if (!pd_->is_fwd() && isa == avx2) prepare_l_relu_mask_avx2();
496 }
497 }
498
499 void prepare_l_relu_mask_avx2() {
500 Label l_mask_after;
501 jmp(l_mask_after);
502 align(32);
503 L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
504 for (int i = 0; i < 8; ++i)
505 dd(1 << i);
506 L(l_mask_after);
507 }
508
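    // Forward ReLU helpers: compare dst against zero, pack the per-lane
    // "dst > 0" bits into the workspace (consumed by the backward pass),
    // then zero the non-positive lanes of vdst.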
509 void fwd_process_relu_avx2(Vmm vdst, int offt) {
510 Reg64 reg_store_mask = reg_diff_scale;
511 Reg64 reg_soff_loc = jbp_->is_nspc_ ? reg_soff_nspc : reg_soff;
512 shr(reg_soff_loc, bit_shift());
513 vcmpps(vtmp, vzero, vdst, _cmp_lt_os);
514 vmovmskps(reg_store_mask, vtmp);
515 mov(ptr[reg_ws + reg_soff_loc + offt / (1 << bit_shift())],
516 reg_store_mask.cvt8());
517 vblendvps(vdst, vzero, vdst, vtmp);
518 shl(reg_soff_loc, bit_shift());
519 }
520
521 void fwd_process_relu_avx512_common(Vmm vdst, int offt = 0) {
522 Reg64 reg_soff_loc = jbp_->is_nspc_ ? reg_soff_nspc : reg_soff;
523 shr(reg_soff_loc, bit_shift());
524 vcmpps(kstore_mask, vzero, vdst, _cmp_lt_os);
525 kmovw(ptr[reg_ws + reg_soff_loc + offt / (1 << bit_shift())],
526 kstore_mask);
527 vblendmps(vdst | kstore_mask, vzero, vdst);
528 shl(reg_soff_loc, bit_shift());
529 }
530
531 void fwd_process_relu_alpha(Vmm vmm_dst) {
532 if (isa == avx512_core)
533 fwd_process_relu_alpha_avx512_common(vmm_dst);
534 else {
535 assert(utils::one_of(isa, avx2, sse41));
536 if (vmm_dst.getIdx() == 0) {
537 uni_vmovups(vdst_aux, vmm_dst);
538 fwd_process_relu_alpha_avx2(vdst_aux);
539 uni_vmovups(Vmm(0), vdst_aux);
540 } else
541 fwd_process_relu_alpha_avx2(vmm_dst);
542 }
543 }
544 void fwd_process_relu_alpha_avx512_common(Vmm vmm_dst) {
545 const Xmm xmm_tmp = Xmm(vtmp.getIdx());
546 vmovq(xmm_tmp, ptr[rsp + stack_off_relu_alpha]);
547 vbroadcastss(vtmp, xmm_tmp);
548 vcmpps(kstore_mask, vzero, vmm_dst, _cmp_lt_os);
549 vmulps(vtmp, vmm_dst, vtmp);
550 vblendmps(vmm_dst | kstore_mask, vtmp, vmm_dst);
551 }
552
553 void fwd_process_relu_alpha_avx2(Vmm vmm_dst) {
554 const Xmm xmm_tmp = Xmm(vtmp.getIdx());
555 uni_vpxor(vmask, vmask, vmask);
556 if (isa == sse41) {
557 mov(reg_tmp_alpha, ptr[rsp + stack_off_relu_alpha]);
558 uni_vmovq(xmm_tmp, reg_tmp_alpha);
559 } else
560 vmovq(xmm_tmp, ptr[rsp + stack_off_relu_alpha]);
561 uni_vbroadcastss(vtmp, xmm_tmp);
562 uni_vcmpps(vmask, vmm_dst, vzero, _cmp_lt_os);
563 uni_vmulps(vtmp, vtmp, vmm_dst);
564 uni_vblendvps(vmm_dst, vmm_dst, vtmp, vmask);
565 }
566
567 void bwd_process_relu_avx2(Vmm vdiff_dst, int offt) {
568 shr(reg_soff, bit_shift());
569 vpbroadcastb(vtmp, ptr[reg_ws + reg_soff + offt / (1 << bit_shift())]);
570 vpand(vtmp, vtmp, ptr[rip + l_relu_mask_avx2]);
571 vpcmpeqd(vtmp, vtmp, ptr[rip + l_relu_mask_avx2]);
572 vblendvps(vdiff_dst, vzero, vdiff_dst, vtmp);
573 shl(reg_soff, bit_shift());
574 }
575
576 void bwd_process_relu_avx512_common(Vmm vdiff_dst, int offt = 0) {
577 shr(jbp_->is_nspc_ ? reg_soff_nspc : reg_soff, bit_shift());
578 kmovw(kstore_mask,
579 ptr[reg_ws + (jbp_->is_nspc_ ? reg_soff_nspc : reg_soff)
580 + offt / (1 << bit_shift())]);
581 vmovups(vdiff_dst | kstore_mask | T_z, vdiff_dst);
582 shl(jbp_->is_nspc_ ? reg_soff_nspc : reg_soff, bit_shift());
583 }
584
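    // The avx2 ne-convert loads (vcvtnee*/vcvtneo*) split a row into its
    // even- and odd-indexed elements; this helper restores plain element
    // order, leaving the first simd_w elements in vmm_even and the next
    // simd_w in vmm_odd.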
585 void merge_interleaved_to_plain(
586 const Vmm &vmm_even, const Vmm &vmm_odd, const Vmm &vmm_aux0) {
587 Ymm ymm_even = Ymm(vmm_even.getIdx());
588 Ymm ymm_odd = Ymm(vmm_odd.getIdx());
589 Ymm ymm_aux0 = Ymm(vmm_aux0.getIdx());
590 Ymm ymm_aux1 = ymm_odd;
591
592 vpunpckldq(ymm_aux0, ymm_even, ymm_odd);
593 vpunpckhdq(ymm_aux1, ymm_even, ymm_odd);
594 vperm2i128(ymm_even, ymm_aux0, ymm_aux1, 0x20);
595 vperm2i128(ymm_odd, ymm_aux0, ymm_aux1, 0x31);
596 }
597 void uni_vmovups_spat_data(
598 const Vmm &vmm_even, const Vmm &vmm_odd, const Address &addr) {
        // load two simd_w-wide rows of data from addr into two registers
600 if (is_bf16_) {
601 // convert bf16 input to f32
602 vcvtneebf162ps(vmm_even, addr);
603 vcvtneobf162ps(vmm_odd, addr);
604 } else if (is_f16_) {
605 vcvtneeph2ps(vmm_even, addr);
606 vcvtneoph2ps(vmm_odd, addr);
607 } else
608 assert(!"unsupported data type!");
609 }
610
611 void uni_vmovups_spat_data(
612 const Operand &dst, const Operand &src, bool is_nt_store = false) {
613 if (dst.isMEM()) {
614 if (is_bf16_) {
615 constexpr bool isAvx2 = isa == avx2;
616 const typename std::conditional<isAvx2, Xmm, Ymm>::type
617 dst_reg {src.getIdx()};
618 const typename std::conditional<isAvx2, Ymm, Zmm>::type
619 src_reg {src.getIdx()};
620
621 // convert f32 output to bf16
622 if (!use_bf16_emulation())
623 vcvtneps2bf16(dst_reg, src_reg,
624 mayiuse(avx512_core) ? Xbyak::EvexEncoding
625 : Xbyak::VexEncoding);
626 else
627 bf16_emu_->vcvtneps2bf16(dst_reg, src_reg);
628
629 // store to memory
630 if (is_nt_store)
631 uni_vmovntps(dst.getAddress(), dst_reg);
632 else
633 uni_vmovups(dst.getAddress(), dst_reg);
634 } else if (is_f16_) {
635 auto src_reg = Vmm(src.getIdx());
636 auto dst_reg =
637 typename vreg_traits<Vmm>::Vmm_lower_t(src.getIdx());
638 if (is_nt_store) {
639 if (mayiuse(avx512_core_fp16))
640 vcvtps2phx(dst_reg, src_reg);
641 else
642 vcvtps2ph(dst_reg, src_reg, _op_mxcsr);
643 uni_vmovntps(dst.getAddress(), dst_reg);
644 } else {
645 vcvtps2ph(dst.getAddress(), src_reg, _op_mxcsr);
646 }
647 } else {
648 if (is_nt_store)
649 uni_vmovntps(dst.getAddress(), Vmm(src.getIdx()));
650 else
651 uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
652 }
653 } else {
654 if (is_bf16_) {
655 // convert bf16 input to f32
656 vpmovzxwd(Vmm(dst.getIdx()), src.getAddress());
657 vpslld(Vmm(dst.getIdx()), Vmm(dst.getIdx()), 0x10);
658 } else if (is_f16_) {
659 if (mayiuse(avx512_core_fp16))
660 vcvtph2psx(Vmm(dst.getIdx()), src.getAddress());
661 else
662 vcvtph2ps(Vmm(dst.getIdx()), src.getAddress());
663 } else {
664 uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
665 }
666 }
667 }
668
669 void uni_vmovups_tail_avx2_common(
670 const Operand &dst, const Operand &src, Label &l_ret) {
671 if (dst.isMEM()) {
672 vmaskmovps(dst.getAddress(), vtail_mask, Vmm(src.getIdx()));
673 } else {
674 vmaskmovps(Vmm(dst.getIdx()), vtail_mask, src.getAddress());
675 }
676 jmp(l_ret);
677 }
678
679 void uni_vmovups_tail_avx512_common(
680 const Operand &dst, const Operand &src, Label &l_ret) {
681 if (dst.isMEM())
682 uni_vmovups(dst.getAddress() | ktail_mask | T_z, Vmm(src.getIdx()));
683 else
684 uni_vmovups(Vmm(dst.getIdx()) | ktail_mask | T_z, src.getAddress());
685
686 jmp(l_ret);
687 }
688
689 void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) {
690 Label l_no_mask, l_ret;
691
692 if (is_c_padded()) {
693 mov(reg_tmp, ptr[rsp + stack_off_is_cblk_tail]);
694 cmp(reg_tmp, 0);
695 jz(l_no_mask);
696
697 lea(reg_tmp, ptr[reg_coff + vlen]);
698 cmp(reg_tmp, reg_coff_max);
699 jl(l_no_mask);
700 assert(isa == avx512_core || isa == avx2);
701 if (isa == avx512_core)
702 uni_vmovups_tail_avx512_common(dst, src, l_ret);
703 else if (isa == avx2)
704 uni_vmovups_tail_avx2_common(dst, src, l_ret);
705 }
706 L(l_no_mask);
707 if (dst.isMEM())
708 uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
709 else
710 uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
711
712 L(l_ret);
713 }
714
715 void barrier() {
716 mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
717 mov(reg_bar, ptr[rsp + stack_off_barrier]);
718 simple_barrier::generate(*this, reg_bar, reg_nnthr);
719 }
720
721 Address mean_ptr(size_t offt = 0) {
722 return vmmword[reg_mean + reg_coff + offt];
723 }
724
725 Address var_ptr(size_t offt = 0) {
726 return vmmword[reg_var + reg_coff + offt];
727 }
728
729 Address diff_gamma_ptr(size_t offt = 0) {
730 return vmmword[reg_diff_scale + reg_coff + offt];
731 }
732
733 Address diff_beta_ptr(size_t offt = 0) {
734 return vmmword[reg_diff_shift + reg_coff + offt];
735 }
736
737 Address gamma_ptr(size_t offt = 0) {
738 return vmmword[reg_scale + reg_coff + offt];
739 }
740
741 Address beta_ptr(size_t offt = 0) {
742 return vmmword[reg_shift + reg_coff + offt];
743 }
744
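    // Generic unrolled loop over the spatial dimension: init(r) sets up
    // `regs` accumulators, body(r, i) is emitted for blocks * regs points
    // per trip plus a scalar tail, and fini(r) folds the accumulators
    // together; reg_soff advances by spat_step per point and reg_ctr counts
    // the remaining points.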
745 template <typename init_t, typename body_t, typename fini_t>
746 void spat_loop(size_t len, size_t blocks, size_t regs, init_t init,
747 body_t body, fini_t fini) {
748 size_t factor = regs * blocks;
749 size_t loop_unroll = len / factor * factor;
750 size_t loop_tail = len - loop_unroll;
751 size_t num_active_regs = (len < regs) ? len : regs;
752 for (size_t i = 0; i < num_active_regs; i++)
753 init(i);
754 if (loop_unroll) {
755 if (jbp_->is_spatial_thr_) {
756 mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]);
757 add(reg_soff, ptr[rsp + stack_off_s_s]);
758 } else {
759 mov(reg_ctr, loop_unroll);
760 }
761 Label label;
762 L(label);
763 {
764 for (size_t i = 0; i < factor; i++) {
765 size_t base_reg = i % regs;
766 body(base_reg, i);
767 }
768 add(reg_soff, factor * spat_step);
769 sub(reg_ctr, factor);
770 jnz(label);
771 }
772 if (jbp_->is_spatial_thr_) {
773 add(reg_soff, ptr[rsp + stack_off_s_tail]);
774 }
775 }
776
777 for (size_t i = 0; i < loop_tail; i++) {
778 size_t base_reg = i % regs;
779 body(base_reg, i);
780 }
781 if (loop_tail) add(reg_soff, loop_tail * spat_step);
782
783 for (size_t i = 0; i < num_active_regs; i++)
784 fini(i);
785 }
786
787 void mean_channels() {
788 Label ch_label;
789 L(ch_label);
790 {
791 uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
792 spat_loop(
793 spat_size, unroll_blocks, unroll_regs,
794 [=](size_t base_reg) {
795 Vmm v = Vmm(base_reg * 2);
796 if (base_reg) uni_vpxor(v, v, v);
797 },
798 [=](size_t base_reg, size_t i) {
799 Vmm v0 = Vmm(base_reg * 2 + 0);
800 Vmm v1 = Vmm(base_reg * 2 + 1);
801 size_t offt = i * vlen_spat_data_;
802 uni_vmovups_spat_data(
803 v1, vmmword[reg_src + reg_soff + offt]);
804 uni_vaddps(v0, v0, v1);
805 },
806 [=](size_t base_reg) {
807 Vmm b = Vmm(0);
808 Vmm v = Vmm(base_reg * 2);
809 if (base_reg) uni_vaddps(b, b, v);
810 });
811 uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
812
813 add(reg_coff, vlen);
814 cmp(reg_coff, reg_coff_max);
815 jl(ch_label);
816 }
817 }
818
819 void mean_variance_nspc(
820 const int num_ch_blks, int num_spat_pts, bool compute_mean) {
821
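        // Register usage: Vmm(0 .. num_ch_blks-1) accumulate the
        // per-channel-block sums (or sums of squared differences); on the
        // variance pass Vmm(num_ch_blks .. 2*num_ch_blks-1) cache the
        // pre-loaded means.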
822 auto mean_compute_avx2_ne_xf16 = [=](int num_ch_blks,
823 int num_spat_pts) {
824 for (int spat_pt = 0; spat_pt < num_spat_pts; ++spat_pt) {
825 for (int ch_idx = 0; ch_idx < num_ch_blks; ch_idx += 2) {
826 const int offt = ch_idx * vlen_spat_data_;
827 const bool is_ch_blks_tail = num_ch_blks - ch_idx < 2;
828 const Vmm vsrc_even = vtmp;
829 const Vmm vsrc_odd = vsrc_aux;
830 if (is_ch_blks_tail)
831 uni_vmovups_spat_data(vsrc_even,
832 vmmword[reg_src + reg_soff_nspc + offt]);
833 else
834 uni_vmovups_spat_data(vsrc_even, vsrc_odd,
835 vmmword[reg_src + reg_soff_nspc + offt]);
836
837 uni_vaddps(Vmm(ch_idx), Vmm(ch_idx), vsrc_even);
838 if (!is_ch_blks_tail)
839 uni_vaddps(Vmm(ch_idx + 1), Vmm(ch_idx + 1), vsrc_odd);
840 }
841 add(reg_soff_nspc, spat_step);
842 }
843 };
844
845 auto variance_compute_avx2_ne_xf16 = [=](int num_ch_blks,
846 int num_spat_pts) {
847 for (int spat_pt = 0; spat_pt < num_spat_pts; ++spat_pt) {
848 for (int ch_idx = 0; ch_idx < num_ch_blks; ch_idx += 2) {
849 const int offt = ch_idx * vlen_spat_data_;
850 const bool is_ch_blks_tail = num_ch_blks - ch_idx < 2;
851 const Vmm vsrc_even = vtmp;
852 const Vmm vsrc_odd = vsrc_aux;
853 const Vmm vmean_ch_even = Vmm(ch_idx + num_ch_blks);
854 const Vmm vmean_ch_odd = Vmm(ch_idx + 1 + num_ch_blks);
855 if (is_ch_blks_tail)
856 uni_vmovups_spat_data(vsrc_even,
857 vmmword[reg_src + reg_soff_nspc + offt]);
858 else
859 uni_vmovups_spat_data(vsrc_even, vsrc_odd,
860 vmmword[reg_src + reg_soff_nspc + offt]);
861 uni_vsubps(vsrc_even, vsrc_even, vmean_ch_even);
862 uni_vfmadd231ps(Vmm(ch_idx), vsrc_even, vsrc_even);
863 if (!is_ch_blks_tail) {
864 uni_vsubps(vsrc_odd, vsrc_odd, vmean_ch_odd);
865 uni_vfmadd231ps(Vmm(ch_idx + 1), vsrc_odd, vsrc_odd);
866 }
867 }
868 add(reg_soff_nspc, spat_step);
869 }
870 };
871
872 auto mean_compute = [=](int num_ch_blks, int num_spat_pts) {
873 for (int spat_pt = 0; spat_pt < num_spat_pts; ++spat_pt) {
874 for (int ch_idx = 0; ch_idx < num_ch_blks; ++ch_idx) {
875 const int offt = ch_idx * vlen_spat_data_;
876 const Vmm vsrc = vtmp;
877 uni_vmovups_spat_data(
878 vsrc, vmmword[reg_src + reg_soff_nspc + offt]);
879 uni_vaddps(Vmm(ch_idx), Vmm(ch_idx), vsrc);
880 }
881 add(reg_soff_nspc, spat_step);
882 }
883 };
884
885 auto variance_compute = [=](int num_ch_blks, int num_spat_pts) {
886 for (int spat_pt = 0; spat_pt < num_spat_pts; ++spat_pt) {
887 for (int ch_idx = 0; ch_idx < num_ch_blks; ++ch_idx) {
888 const int offt = ch_idx * vlen_spat_data_;
889 const Vmm vsrc = vtmp;
890 const Vmm vmean_ch = Vmm(ch_idx + num_ch_blks);
891 uni_vmovups_spat_data(
892 vsrc, vmmword[reg_src + reg_soff_nspc + offt]);
893 uni_vsubps(vsrc, vsrc, vmean_ch);
894 uni_vfmadd231ps(Vmm(ch_idx), vsrc, vsrc);
895 }
896 add(reg_soff_nspc, spat_step);
897 }
898 };
899
900 for (int idx = 0; idx < num_ch_blks; ++idx) {
901 const int coff = idx * vlen;
902 uni_vmovups(Vmm(idx), vmmword[reg_rbuf1 + reg_coff + coff]);
903 if (!compute_mean) {
904 // pre-load mean to avoid extra data movement during variance
905 const Vmm vmean_ch = Vmm(idx + num_ch_blks);
906 uni_vmovups_maybe_tail(vmean_ch, mean_ptr(coff));
907 }
908 }
909
910 xor_(reg_soff_nspc, reg_soff_nspc);
911
912 if (jbp_->is_spatial_thr_) {
913 mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]);
914 add(reg_soff_nspc, ptr[rsp + stack_off_s_s]);
915 // TODO: need a better heuristic for num_spat_pts
916 num_spat_pts = 1;
917 } else {
918 mov(reg_ctr, spat_size);
919 num_spat_pts = nstl::min((size_t)num_spat_pts, spat_size);
920 // TODO: unroll by spatial
921 if (spat_size % num_spat_pts != 0) num_spat_pts = 1;
922 }
923
924 Label spatial;
925 L(spatial);
926 {
927 if (is_avx2_ne_xf16_)
928 compute_mean
929 ? mean_compute_avx2_ne_xf16(num_ch_blks, num_spat_pts)
930 : variance_compute_avx2_ne_xf16(
931 num_ch_blks, num_spat_pts);
932 else
933 compute_mean ? mean_compute(num_ch_blks, num_spat_pts)
934 : variance_compute(num_ch_blks, num_spat_pts);
935 sub(reg_ctr, num_spat_pts);
936 jnz(spatial, T_NEAR);
937 }
938
939 for (int idx = 0; idx < num_ch_blks; ++idx) {
940 const int coff = idx * vlen;
941 uni_vmovups(vmmword[reg_rbuf1 + reg_coff + coff], Vmm(idx));
942 }
943 }
944
945 void forward_channels_nspc_compute(const int num_ch_blks) {
946 auto compute = [=](bool stream_store_allowed) {
947 // Overwritten during mean and variance computation
948 uni_vpxor(vzero, vzero, vzero);
949
950 xor_(reg_soff_nspc, reg_soff_nspc);
951
952 if (jbp_->is_spatial_thr_) {
953 mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]);
954 add(reg_soff_nspc, ptr[rsp + stack_off_s_s]);
955 } else {
956 mov(reg_ctr, spat_size);
957 }
958
959 // TODO: spatial blocking
960 const int num_spat_pts = 1;
961
            // pre-compute the scale for each channel to avoid costly div and
            // sqrt; merge variances from interleaved to plain layout if needed
964 for (int idx = 0; idx < num_ch_blks; idx += 2) {
965 const int coff_base = idx * vlen;
966 const bool is_ch_blks_tail = num_ch_blks - idx < 2;
967 const Vmm vvar_even = Vmm(idx);
968 const Vmm vvar_odd = Vmm(idx + 1);
969 if (!is_ch_blks_tail) {
970 uni_vmovups_maybe_tail(vvar_even, var_ptr(coff_base));
971 uni_vmovups_maybe_tail(vvar_odd, var_ptr(coff_base + vlen));
972 if (is_avx2_ne_xf16_ && !pd_->stats_is_src())
973 merge_interleaved_to_plain(vvar_even, vvar_odd, vtmp);
974 } else
975 uni_vmovups_maybe_tail(vvar_even, var_ptr(coff_base));
976
977 for (int i_odd = 0; i_odd < 2 && idx + i_odd < num_ch_blks;
978 ++i_odd) {
979 const int coff = coff_base + i_odd * vlen;
980 const Vmm vscale = Vmm(idx + i_odd + num_ch_blks);
981 const Vmm vvar = i_odd ? vvar_odd : vvar_even;
982 uni_vmovups(vsqrtvar, vvar);
983 uni_vaddps(vsqrtvar, vsqrtvar, veps);
984 uni_vsqrtps(vsqrtvar, vsqrtvar);
985
986 if (pd_->use_scale()) {
987 uni_vmovups_maybe_tail(vgamma, gamma_ptr(coff));
988 uni_vdivps(vscale, vgamma, vsqrtvar, vtmp);
989 } else {
990 uni_vdivps(vscale, vone, vsqrtvar, vtmp);
991 }
992 }
993 }
994
995 Label spatial;
996 L(spatial);
997 {
998 if (is_avx2_ne_xf16_) {
999 for (int idx = 0; idx < num_ch_blks; idx += 2) {
1000 const int offt = idx * vlen_spat_data_;
1001 const int coff = idx * vlen;
1002 const bool is_ch_blks_tail = num_ch_blks - idx < 2;
1003 Vmm vdata_even = Vmm(idx);
1004 Vmm vdata_odd = Vmm(idx + 1);
1005 if (is_ch_blks_tail) {
1006 uni_vmovups_spat_data(vdata_even,
1007 vmmword[reg_src + reg_soff_nspc + offt]);
1008 if (!pd_->stats_is_src())
1009 uni_vsubps(
1010 vdata_even, vdata_even, mean_ptr(coff));
1011 } else {
1012 uni_vmovups_spat_data(vdata_even, vdata_odd,
1013 vmmword[reg_src + reg_soff_nspc + offt]);
                        // subtract the interleaved mean from the interleaved
                        // data before merging to plain layout when needed
1016 if (!pd_->stats_is_src()) {
1017 uni_vsubps(
1018 vdata_even, vdata_even, mean_ptr(coff));
1019 uni_vsubps(vdata_odd, vdata_odd,
1020 mean_ptr(coff + vlen));
1021 }
1022 merge_interleaved_to_plain(
1023 vdata_even, vdata_odd, vtmp);
1024 }
1025 }
1026 }
1027
1028 for (int idx = 0; idx < num_ch_blks; ++idx) {
1029 const int coff = idx * vlen;
1030 const int offt = idx * vlen_spat_data_;
1031 const Vmm vdata = Vmm(idx);
1032 const Vmm vscale = Vmm(idx + num_ch_blks);
1033 uni_vmovups_maybe_tail(vmean, mean_ptr(coff));
1034
1035 if (pd_->use_shift()) {
1036 uni_vmovups_maybe_tail(vbeta, beta_ptr(coff));
1037 }
1038
1039 if (!is_avx2_ne_xf16_)
1040 uni_vmovups_spat_data(
1041 vdata, vmmword[reg_src + reg_soff_nspc + offt]);
1042 if (IMPLICATION(is_avx2_ne_xf16_, pd_->stats_is_src()))
1043 uni_vsubps(vdata, vdata, vmean);
1044
1045 if (pd_->use_shift()) {
1046 // --flags=S,CH,H
1047 uni_vfmadd213ps(vdata, vscale, vbeta);
1048 } else {
1049 // --flags=,C
1050 uni_vmulps(vdata, vdata, vscale);
1051 }
1052
1053 if (with_relu_inf_only) { // --attr=post_ops='relu'
1054 if (pd_->alpha() != 0.f)
1055 fwd_process_relu_alpha(vdata);
1056 else
1057 uni_vmaxps(vdata, vdata, vzero);
1058 } else if (with_relu) { // --flags=R
1059 if (isa == avx512_core)
1060 fwd_process_relu_avx512_common(vdata, offt);
1061 else if (isa == avx2)
1062 fwd_process_relu_avx2(vdata, offt);
1063 else
1064 assert(false);
1065 }
1066 uni_vmovups_spat_data(
1067 vmmword[reg_dst + reg_soff_nspc + offt], vdata,
1068 stream_store_allowed);
1069 }
1070 add(reg_soff_nspc, spat_step);
1071 sub(reg_ctr, num_spat_pts);
1072 jnz(spatial, T_NEAR);
1073 }
1074 };
1075
1076 if (stream_store_supported()) {
1077 Label normal_store, end_store;
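            // Non-temporal stores need a vector-aligned destination; fall
            // back to regular stores when reg_dst is not aligned.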
1078 test(reg_dst, vlen_spat_data_ - 1);
1079 jnz(normal_store, T_NEAR);
1080 compute(true);
1081 jmp(end_store, T_NEAR);
1082 L(normal_store);
1083 { compute(false); }
1084 L(end_store);
1085 } else {
            compute(false); // disabled for xf16 when data fits in cache
1087 }
1088 }
1089
1090 void compute_mean_variance_nspc(bool compute_mean = true) {
1091 xor_(reg_coff, reg_coff);
1092 mov(reg_coff_max_fwd_copy, reg_coff_max);
1093
1094 Label ch_unroll_label[5];
1095 const int max_ch_unroll = isa == avx512_core ? 4 : 2;
1096
1097 // TODO: Spatial and channel unrolling decisions should be made during
1098 // initialization depending on the problem size
1099 for (int ch_idx = max_ch_unroll, sp_idx = 1; ch_idx > 0;
1100 --ch_idx, ++sp_idx) {
1101 L(ch_unroll_label[ch_idx]);
1102 {
1103 const int ch_blk_size = (1 << (ch_idx - 1)); // 8, 4, 2, 1
1104 cmp(reg_coff_max, vlen * ch_blk_size);
1105 jl(ch_unroll_label[ch_idx - 1], T_NEAR);
1106
1107 const int spat_blk_size = (1 << sp_idx);
1108 mean_variance_nspc(ch_blk_size, spat_blk_size, compute_mean);
1109
1110 add(reg_src, vlen_spat_data_ * ch_blk_size);
1111 add(reg_coff, vlen * ch_blk_size);
1112
1113 sub(reg_coff_max, vlen * ch_blk_size);
1114 jmp(ch_unroll_label[ch_idx], T_NEAR);
1115 }
1116 }
1117 L(ch_unroll_label[0]);
1118
1119 // comeback
1120 mov(reg_coff_max, reg_coff_max_fwd_copy);
1121
1122 if (is_xf16()) shr(reg_coff_max, 1);
1123 sub(reg_src, reg_coff_max);
1124 if (is_xf16()) shl(reg_coff_max, 1);
1125 }
1126
1127 void var_channels() {
1128 Label ch_label;
1129 L(ch_label);
1130 {
1131 uni_vmovups_maybe_tail(vmean, mean_ptr());
1132 uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
1133 spat_loop(
1134 spat_size, unroll_blocks, unroll_regs,
1135 [=](size_t base_reg) {
1136 Vmm v = Vmm(base_reg * 3);
1137 if (base_reg > 0) uni_vpxor(v, v, v);
1138 },
1139 [=](size_t base_reg, size_t i) {
1140 Vmm v = Vmm(3 * base_reg);
1141 Vmm vtmp0 = Vmm(3 * base_reg + 1);
1142 Vmm vtmp1 = Vmm(3 * base_reg + 2);
1143 size_t offt = i * vlen_spat_data_;
1144 uni_vmovups_spat_data(
1145 vtmp0, vmmword[reg_src + reg_soff + offt]);
1146 if (isa == sse41) {
1147 movups(vtmp1, vmean);
1148 subps(vtmp1, vtmp0);
1149 } else {
1150 vsubps(vtmp1, vmean, vtmp0);
1151 }
1152 uni_vfmadd231ps(v, vtmp1, vtmp1);
1153 },
1154 [=](size_t base_reg) {
1155 Vmm b = Vmm(0);
1156 Vmm v = Vmm(base_reg * 3);
1157 if (base_reg) uni_vaddps(b, b, v);
1158 });
1159 uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
1160 add(reg_coff, vlen);
1161 cmp(reg_coff, reg_coff_max);
1162 jl(ch_label);
1163 }
1164 }
1165
1166 void compute_mean_variance() {
1167 uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
1168 xor_(reg_coff, reg_coff);
1169 Label zero_rbuf;
1170 L(zero_rbuf);
1171 {
1172 uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
1173 add(reg_coff, isa == sse41 ? vlen / 2 : vlen);
1174 cmp(reg_coff, reg_coff_max);
1175 jne(zero_rbuf);
1176 }
1177
1178 mov(reg_src, ptr[rsp + stack_off_src]);
1179
1180 xor_(reg_soff, reg_soff);
1181 Label mean_spatial;
1182 L(mean_spatial);
1183 {
1184 xor_(reg_coff, reg_coff);
1185
1186 if (isa == sse41) mov(reg_tmp_off, reg_soff);
1187
1188 jbp_->is_nspc_ ? compute_mean_variance_nspc() : mean_channels();
1189
1190 if (isa == sse41) {
1191 mov(reg_soff, reg_tmp_off);
1192 add(reg_src, vlen / 2);
1193 mov(reg_coff, vlen / 2);
1194
1195 mean_channels();
1196
1197 sub(reg_src, vlen / 2);
1198 }
1199
1200 // Process next image
1201 if (jbp_->is_nspc_) {
                // Can use a static offset since we come back after the
                // spatial loop
1203 add(reg_src, mb_offt);
1204 add(reg_soff, mb_offt);
1205 } else {
1206 add(reg_soff, reg_mb_stride_Bc);
1207 }
1208
1209 cmp(reg_soff, reg_soff_max);
1210 jl(mean_spatial);
1211 }
1212
1213 if (jbp_->is_nspc_) mov(reg_src, ptr[rsp + stack_off_src]); // comeback
1214
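        // Cross-thread reduction: each thread accumulated its partial sums
        // into its own reg_coff_max-sized slice of rbuf1; the thread with
        // N_ithr == 0 adds the slices, divides by chan_size and writes the
        // result to mean_ptr(), while the barriers keep all threads in sync.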
1215 Label no_mean_reduction;
1216 barrier();
1217 {
1218 mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
1219 cmp(reg_tmp, 0);
1220 jne(no_mean_reduction);
1221 mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
1222 xor_(reg_coff, reg_coff);
1223 Label mean_reduction_channels;
1224 L(mean_reduction_channels);
1225 {
1226 mov(reg_roff, reg_coff);
1227 uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
1228 uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
1229 mov(reg_ctr, reg_nnthr);
1230 Label mean_reduction_thrs;
1231 L(mean_reduction_thrs);
1232 {
1233 uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
1234 uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0));
1235 add(reg_roff, reg_coff_max);
1236 sub(reg_ctr, 1);
1237 jnz(mean_reduction_thrs);
1238 }
1239 uni_vdivps(Vmm(1), Vmm(1), vchan_size);
1240 uni_vmovups_maybe_tail(mean_ptr(), Vmm(1));
1241
1242 add(reg_coff, isa == sse41 ? vlen / 2 : vlen);
1243
1244 cmp(reg_coff, reg_coff_max);
1245 jl(mean_reduction_channels);
1246 }
1247 }
1248 L(no_mean_reduction);
1249 barrier();
1250
1251 xor_(reg_soff, reg_soff);
1252 Label var_spatial;
1253 L(var_spatial);
1254 {
1255 xor_(reg_coff, reg_coff);
1256
1257 if (isa == sse41) mov(reg_tmp_off, reg_soff);
1258
1259 jbp_->is_nspc_ ? compute_mean_variance_nspc(false) : var_channels();
1260
1261 if (isa == sse41) {
1262 mov(reg_soff, reg_tmp_off);
1263 add(reg_src, vlen / 2);
1264 mov(reg_coff, vlen / 2);
1265
1266 var_channels();
1267
1268 sub(reg_src, vlen / 2);
1269 }
1270
1271 // Process next image
1272 if (jbp_->is_nspc_) {
                // Can use a static offset since we come back after the
                // spatial loop
1274 add(reg_src, mb_offt);
1275 add(reg_soff, mb_offt);
1276 } else {
1277 add(reg_soff, reg_mb_stride_Bc);
1278 }
1279
1280 cmp(reg_soff, reg_soff_max);
1281 jl(var_spatial);
1282 }
1283
1284 if (jbp_->is_nspc_) mov(reg_src, ptr[rsp + stack_off_src]); // comeback
1285
1286 Label no_var_reduction;
1287 barrier();
1288 {
1289 mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
1290 cmp(reg_tmp, 0);
1291 jne(no_var_reduction);
1292
1293 mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
1294 xor_(reg_coff, reg_coff);
1295 Label var_reduction_channels;
1296 L(var_reduction_channels);
1297 {
1298 mov(reg_roff, reg_coff);
1299 uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
1300 mov(reg_ctr, reg_nnthr);
1301 Label var_reduction_thrs;
1302 L(var_reduction_thrs);
1303 { // TODO: unroll (?)
1304 uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
1305 add(reg_roff, reg_coff_max);
1306 sub(reg_ctr, 1);
1307 jnz(var_reduction_thrs);
1308 }
1309 uni_vdivps(Vmm(1), Vmm(1), vchan_size);
1310 uni_vmovups_maybe_tail(var_ptr(), Vmm(1));
1311 add(reg_coff, isa == sse41 ? vlen / 2 : vlen);
1312
1313 cmp(reg_coff, reg_coff_max);
1314 jne(var_reduction_channels);
1315 }
1316 }
1317 L(no_var_reduction);
1318 barrier();
1319 }
1320
1321 void forward_channels() {
1322 Label ch_label;
1323 L(ch_label);
1324 {
1325 uni_vmovups_maybe_tail(vmean, mean_ptr());
1326 uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
1327 uni_vaddps(vsqrtvar, vsqrtvar, veps);
1328 uni_vsqrtps(vsqrtvar, vsqrtvar);
1329
1330 if (pd_->use_scale()) {
1331 uni_vmovups_maybe_tail(vgamma, gamma_ptr());
1332 }
1333 if (pd_->use_shift()) { uni_vmovups_maybe_tail(vbeta, beta_ptr()); }
1334
1335 Vmm vscale = (pd_->use_scale()) ? vgamma : vone;
1336 Vmm vdiv = (pd_->use_scale()) ? vgamma : vsqrtvar;
1337
1338 if (isa == sse41) {
1339 movups(vtmp, vscale);
1340 divps(vtmp, vsqrtvar);
1341 movups(vdiv, vtmp);
1342 } else {
1343 vdivps(vdiv, vscale, vsqrtvar);
1344 }
1345
1346 const auto spat_loop_init_fin
1347 = [](size_t base_reg) { UNUSED(base_reg); };
1348
1349 const auto spat_loop_body = [=](size_t base_reg, size_t i,
1350 bool stream_store_allowed) {
1351 const Vmm v = Vmm(base_reg);
1352 const size_t offt = i * vlen_spat_data_;
1353 uni_vmovups_spat_data(v, vmmword[reg_src + reg_soff + offt]);
1354 uni_vsubps(v, v, vmean);
1355 if ((pd_->use_scale() && pd_->use_shift())) {
1356 // --flags=CH
1357 uni_vfmadd213ps(v, vgamma, vbeta);
1358 } else if (pd_->use_scale()) {
1359 // --flags=C
1360 uni_vmulps(v, v, vgamma);
1361 } else if (pd_->use_shift()) {
1362 // --flags=H
1363 uni_vfmadd213ps(v, vsqrtvar, vbeta);
1364 } else {
1365 uni_vmulps(v, v, vsqrtvar);
1366 }
1367 if (with_relu_inf_only) { // --attr=post_ops='relu'
1368 if (pd_->alpha() != 0.f) {
1369 fwd_process_relu_alpha(v);
1370 } else
1371 uni_vmaxps(v, v, vzero);
1372 } else if (with_relu) { // --flags=R
1373 if (isa == avx512_core)
1374 fwd_process_relu_avx512_common(v, offt);
1375 else
1376 fwd_process_relu_avx2(v, offt);
1377 }
1378 if (stream_store_allowed) {
1379 uni_vmovntps(vmmword[reg_dst + reg_soff + offt], v);
1380 } else {
1381 uni_vmovups_spat_data(
1382 vmmword[reg_dst + reg_soff + offt], v);
1383 }
1384 };
1385
1386 const auto compute = [=](bool stream_store_allowed) {
1387 using namespace std::placeholders;
1388 spat_loop(spat_size, unroll_blocks, unroll_regs,
1389 spat_loop_init_fin,
1390 std::bind(spat_loop_body, _1, _2, stream_store_allowed),
1391 spat_loop_init_fin);
1392 };
1393
1394 if (stream_store_supported()) {
1395 Label normal_store, end_store;
1396 test(reg_dst, vlen - 1);
1397 jnz(normal_store, T_NEAR);
1398 compute(true);
1399 jmp(end_store, T_NEAR);
1400 L(normal_store);
1401 { compute(false); }
1402 L(end_store);
1403 } else {
                compute(false); // no NT store for xf16 on this path
1405 }
1406
1407 add(reg_coff, vlen);
1408 cmp(reg_coff, reg_coff_max);
1409 jl(ch_label);
1410 }
1411 }
1412
1413 void forward_channels_nspc() {
1414 xor_(reg_coff, reg_coff);
1415 mov(reg_coff_max_fwd_copy, reg_coff_max);
1416
1417 Label ch_unroll_label[5];
1418 const int max_ch_unroll
1419 = isa == avx512_core ? 4 - use_bf16_emulation() : 2;
1420
1421 // TODO: Spatial and channel unrolling decisions should be made during
1422 // initialization depending on the problem size
1423 for (int ch_idx = max_ch_unroll; ch_idx > 0; --ch_idx) {
1424 L(ch_unroll_label[ch_idx]);
1425 {
1426 const int ch_blk_size = (1 << (ch_idx - 1)); // 8, 4, 2, 1
1427 cmp(reg_coff_max, vlen * ch_blk_size);
1428 jl(ch_unroll_label[ch_idx - 1], T_NEAR);
1429
1430 forward_channels_nspc_compute(ch_blk_size);
1431
1432 add(reg_src, vlen_spat_data_ * ch_blk_size);
1433 add(reg_dst, vlen_spat_data_ * ch_blk_size);
1434
1435 // advance mean_ptr() and var_ptr()
1436 add(reg_coff, vlen * ch_blk_size);
1437
1438 add(reg_ws, (vlen / 32) * ch_blk_size);
1439
1440 sub(reg_coff_max, vlen * ch_blk_size);
1441 jmp(ch_unroll_label[ch_idx], T_NEAR);
1442 }
1443 }
1444 L(ch_unroll_label[0]);
1445
1446 // comeback
1447 mov(reg_coff_max, reg_coff_max_fwd_copy);
1448
1449 if (is_xf16()) shr(reg_coff_max, 1);
1450 sub(reg_src, reg_coff_max);
1451 sub(reg_dst, reg_coff_max);
1452 if (is_xf16()) shl(reg_coff_max, 1);
1453
1454 shr(reg_coff_max, 5);
1455 sub(reg_ws, reg_coff_max);
1456 shl(reg_coff_max, 5);
1457 }
1458
1459 void forward() {
1460 mov(reg_src, ptr[rsp + stack_off_src]);
1461 mov(reg_dst, ptr[rsp + stack_off_dst]);
1462 mov(reg_ws, ptr[rsp + stack_off_ws]);
1463 mov(reg_shift, ptr[rsp + stack_off_shift]);
1464
1465 xor_(reg_soff, reg_soff);
1466 Label dst_spatial;
1467 L(dst_spatial);
1468 {
1469 xor_(reg_coff, reg_coff);
1470 if (isa == sse41) mov(reg_tmp_off, reg_soff);
1471
1472 jbp_->is_nspc_ ? forward_channels_nspc() : forward_channels();
1473
1474 if (isa == sse41) {
1475 mov(reg_soff, reg_tmp_off);
1476 add(reg_src, vlen / 2);
1477 add(reg_dst, vlen / 2);
1478 mov(reg_coff, vlen / 2);
1479
1480 forward_channels();
1481
1482 sub(reg_src, vlen / 2);
1483 sub(reg_dst, vlen / 2);
1484 }
1485
1486 // Process next image
1487 if (jbp_->is_nspc_) {
                // Can use a static offset since we come back after the
                // spatial loop
1489 add(reg_src, mb_offt);
1490 add(reg_dst, mb_offt);
1491 add(reg_soff, mb_offt);
1492 add(reg_ws, ws_mb_offt);
1493 } else {
1494 add(reg_soff, reg_mb_stride_Bc);
1495 }
1496
1497 cmp(reg_soff, reg_soff_max);
1498 jl(dst_spatial);
1499 }
1500
1501 if (jbp_->is_nspc_) {
1502 // comeback
1503 mov(reg_src, ptr[rsp + stack_off_src]);
1504 mov(reg_dst, ptr[rsp + stack_off_dst]);
1505 mov(reg_ws, ptr[rsp + stack_off_ws]);
1506 }
1507 }
1508
1509 void backward_sh_channels() {
1510 Label sh_channels;
1511 L(sh_channels);
1512 {
1513 uni_vmovups_maybe_tail(vmean, mean_ptr());
1514 uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
1515 uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]);
1516 spat_loop(
1517 spat_size, 1, 1,
1518 [=](size_t base_reg) {
1519 if (base_reg > 0) {
1520 for (int i = 0; i < 2; i++) {
1521 Vmm v(base_reg * 5 + i);
1522 uni_vpxor(v, v, v);
1523 }
1524 }
1525 },
1526 [=](size_t base_reg, size_t i) {
1527 // TODO: use single set of tmp regs and let ROB handle the rest
1528 Vmm o0 = Vmm(base_reg * 5 + 0);
1529 Vmm o1 = Vmm(base_reg * 5 + 1);
1530 Vmm t1 = Vmm(base_reg * 5 + 2);
1531 Vmm t2 = Vmm(base_reg * 5 + 3);
1532 Vmm t3 = Vmm(base_reg * 5 + 4);
1533 size_t offt = i * vlen_spat_data_;
1534 uni_vmovups_spat_data(
1535 t1, vmmword[reg_src + reg_soff + offt]);
1536 uni_vmovups_spat_data(
1537 t2, vmmword[reg_diff_dst + reg_soff + offt]);
1538 if (with_relu) {
1539 if (isa == avx512_core)
1540 bwd_process_relu_avx512_common(t2, offt);
1541 else if (isa == avx2)
1542 bwd_process_relu_avx2(t2, offt);
1543 else
1544 assert(false);
1545 }
1546 uni_vsubps(t3, vmean, t1, t3);
1547 if (isa == sse41) {
1548 mulps(t3, t2);
1549 subps(o0, t3);
1550 } else {
1551 vfnmadd231ps(o0, t3, t2);
1552 }
1553 uni_vaddps(o1, o1, t2);
1554 },
1555 [=](size_t base_reg) {
1556 Vmm b0 = Vmm(0);
1557 Vmm b1 = Vmm(1);
1558 if (base_reg) {
1559 uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0));
1560 uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1));
1561 }
1562 });
1563 uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
1564 uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1));
1565 add(reg_coff, vlen);
1566 cmp(reg_coff, reg_coff_max);
1567 jl(sh_channels);
1568 }
1569 }
1570
1571 void backward_sh_channels_nspc_compute(const int num_ch_blks) {
1572 for (int idx = 0; idx < num_ch_blks; ++idx) {
1573 const int offt = idx * vlen;
1574 const Vmm vdiff_gamma_ch = Vmm(idx);
1575 const Vmm vdiff_beta_ch = Vmm(idx + num_ch_blks);
1576 uni_vmovups(vdiff_gamma_ch, vmmword[reg_rbuf1 + reg_coff + offt]);
1577 uni_vmovups(vdiff_beta_ch, vmmword[reg_rbuf2 + reg_coff + offt]);
1578 }
1579
1580 xor_(reg_soff_nspc, reg_soff_nspc);
1581
1582 if (jbp_->is_spatial_thr_) {
1583 mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]);
1584 add(reg_soff_nspc, ptr[rsp + stack_off_s_s]);
1585 } else {
1586 mov(reg_ctr, spat_size);
1587 }
1588
1589 // TODO: spatial blocking
1590 const int num_spat_pts = 1;
1591
1592 Label spatial;
1593 L(spatial);
1594 {
1595 for (int ch_idx = 0; ch_idx < num_ch_blks; ++ch_idx) {
1596 const int coff = ch_idx * vlen;
1597 const int offt = ch_idx * vlen_spat_data_;
1598 const Vmm vdiff_gamma_ch = Vmm(ch_idx);
1599 const Vmm vdiff_beta_ch = Vmm(ch_idx + num_ch_blks);
1600 // vdiff_beta and vdiff_gamma are free registers for nspc
1601 const Vmm vsrc = vdiff_gamma;
1602 const Vmm vdiff_dst = vdiff_beta;
1603 uni_vmovups_maybe_tail(vmean, mean_ptr(coff));
1604
1605 uni_vmovups_spat_data(
1606 vsrc, vmmword[reg_src + reg_soff_nspc + offt]);
1607 uni_vmovups_spat_data(vdiff_dst,
1608 vmmword[reg_diff_dst + reg_soff_nspc + offt]);
1609
1610 if (with_relu) {
1611 if (isa == avx512_core)
1612 bwd_process_relu_avx512_common(vdiff_dst, offt);
1613 else
1614 assert(false);
1615 }
1616
1617 uni_vsubps(vsrc, vsrc, vmean);
1618 uni_vfmadd231ps(vdiff_gamma_ch, vsrc, vdiff_dst);
1619 uni_vaddps(vdiff_beta_ch, vdiff_beta_ch, vdiff_dst);
1620 }
1621 add(reg_soff_nspc, spat_step);
1622 sub(reg_ctr, num_spat_pts);
1623 jnz(spatial, T_NEAR);
1624 }
1625
1626 for (int idx = 0; idx < num_ch_blks; ++idx) {
1627 const Vmm vdiff_gamma_ch = Vmm(idx);
1628 const Vmm vdiff_beta_ch = Vmm(idx + num_ch_blks);
1629 const int offt = idx * vlen;
1630 uni_vmovups(vmmword[reg_rbuf1 + reg_coff + offt], vdiff_gamma_ch);
1631 uni_vmovups(vmmword[reg_rbuf2 + reg_coff + offt], vdiff_beta_ch);
1632 }
1633 }
1634
1635 void backward_sh_channels_nspc() {
1636 xor_(reg_coff, reg_coff);
1637 mov(reg_coff_max_bwd_copy, reg_coff_max);
1638
1639 Label ch_unroll_label[5];
1640 const int max_ch_unroll = 4;
1641
1642 // TODO: Spatial and channel unrolling decisions should be made during
1643 // initialization depending on the problem size
1644 for (int ch_idx = max_ch_unroll; ch_idx > 0; --ch_idx) {
1645 L(ch_unroll_label[ch_idx]);
1646 {
1647 const int ch_blk_size = (1 << (ch_idx - 1)); // 8, 4, 2, 1
1648 cmp(reg_coff_max, vlen * ch_blk_size);
1649 jl(ch_unroll_label[ch_idx - 1], T_NEAR);
1650
1651 backward_sh_channels_nspc_compute(ch_blk_size);
1652
1653 add(reg_src, vlen_spat_data_ * ch_blk_size);
1654 add(reg_diff_dst, vlen_spat_data_ * ch_blk_size);
1655
1656 // advance mean_ptr() and var_ptr()
1657 add(reg_coff, vlen * ch_blk_size);
1658
1659 add(reg_ws, 2 * ch_blk_size);
1660
1661 sub(reg_coff_max, vlen * ch_blk_size);
1662 jmp(ch_unroll_label[ch_idx], T_NEAR);
1663 }
1664 }
1665 L(ch_unroll_label[0]);
1666
1667 // comeback
1668 mov(reg_coff_max, reg_coff_max_bwd_copy);
1669 mov(reg_diff_scale, ptr[rsp + stack_off_diff_scale]);
1670
1671 if (is_xf16()) shr(reg_coff_max, 1);
1672 sub(reg_src, reg_coff_max);
1673 sub(reg_diff_dst, reg_coff_max);
1674 if (is_xf16()) shl(reg_coff_max, 1);
1675
1676 if (with_relu) {
1677 shr(reg_coff_max, 5);
1678 sub(reg_ws, reg_coff_max);
1679 shl(reg_coff_max, 5);
1680 }
1681 }
1682
1683 void backward_diff_channels() {
1684 Label diff_channels;
1685 L(diff_channels);
1686 {
1687 uni_vmovups_maybe_tail(vmean, mean_ptr());
1688 uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
1689 uni_vaddps(vsqrtvar, vsqrtvar, veps);
1690 uni_vsqrtps(vsqrtvar, vsqrtvar);
1691 uni_vdivps(vsqrtvar, vone, vsqrtvar, vtmp);
1692 if (pd_->use_scale()) uni_vmovups_maybe_tail(vgamma, gamma_ptr());
1693 uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr());
1694 uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr());
1695 uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar);
1696 uni_vdivps(vdiff_beta, vdiff_beta, vchan_size);
1697 uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size);
1698
1699 const auto spat_loop_init_fin
1700 = [=](size_t base_reg) { UNUSED(base_reg); };
1701 const auto spat_loop_body = [=](size_t base_reg, size_t i,
1702 bool stream_store_allowed) {
1703 const Vmm v(base_reg * 2 + 0);
1704 const Vmm t(base_reg * 2 + 1);
1705 const Vmm t1(base_reg * 2 + 2);
1706 const size_t offt = i * vlen_spat_data_;
1707 uni_vmovups_spat_data(
1708 v, vmmword[reg_diff_dst + reg_soff + offt]);
1709 if (with_relu) {
1710 if (isa == avx512_core)
1711 bwd_process_relu_avx512_common(v, offt);
1712 else if (isa == avx2)
1713 bwd_process_relu_avx2(v, offt);
1714 else
1715 assert(false);
1716 }
1717 if (!pd_->use_global_stats()) {
1718 uni_vsubps(v, v, vdiff_beta);
1719 uni_vmovups_spat_data(
1720 t, vmmword[reg_src + reg_soff + offt]);
1721 uni_vsubps(t, vmean, t, t1);
1722 uni_vmulps(t, t, vdiff_gamma);
1723 uni_vaddps(v, v, t);
1724 }
1725 uni_vmulps(v, v, vsqrtvar);
1726 if (pd_->use_scale()) { uni_vmulps(v, v, vgamma); }
1727 if (stream_store_allowed) {
1728 uni_vmovntps(vmmword[reg_diff_src + reg_soff + offt], v);
1729 } else {
1730 uni_vmovups_spat_data(
1731 vmmword[reg_diff_src + reg_soff + offt], v);
1732 }
1733 };
1734
1735 const auto compute = [=](bool stream_store_allowed) {
1736 using namespace std::placeholders;
1737 spat_loop(spat_size, unroll_blocks, unroll_regs,
1738 spat_loop_init_fin,
1739 std::bind(spat_loop_body, _1, _2, stream_store_allowed),
1740 spat_loop_init_fin);
1741 };
1742
1743 if (stream_store_supported()) {
1744 Label normal_store, end_store;
1745 test(reg_diff_src, vlen - 1);
1746 jnz(normal_store, T_NEAR);
1747 compute(true);
1748 jmp(end_store, T_NEAR);
1749 L(normal_store);
1750 { compute(false); }
1751 L(end_store);
1752 } else {
                compute(false); // no NT store for xf16 on this path
1754 }
1755
1756 add(reg_coff, vlen);
1757 cmp(reg_coff, reg_coff_max);
1758 jl(diff_channels);
1759 }
1760 }
1761
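    // nspc backward (diff_src): up to three banks of num_ch_blks registers
    // are live at once -- Vmm(idx) holds 1/sqrt(var + eps) and, when local
    // statistics are used, Vmm(idx + num_ch_blks) holds diff_beta/chan_size
    // and Vmm(idx + 2*num_ch_blks) holds
    // diff_gamma/(sqrt(var + eps) * chan_size). This is one reason the
    // channel unroll here is smaller than on the forward path.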
1762 void backward_diff_channels_nspc_compute(const int num_ch_blks) {
1763 auto compute = [=](bool stream_store_allowed) {
1764 xor_(reg_soff_nspc, reg_soff_nspc);
1765 if (jbp_->is_spatial_thr_) {
1766 mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]);
1767 add(reg_soff_nspc, ptr[rsp + stack_off_s_s]);
1768 } else {
1769 mov(reg_ctr, spat_size);
1770 }
1771
1772 // TODO: spatial blocking
1773 const int num_spat_pts = 1;
1774
1775 // pre-compute scale for each channel to avoid costly div and sqrt
1776 if (!pd_->use_global_stats()) {
1777 mov(ptr[rsp + stack_off_ws_off_copy], reg_ws);
1778 mov(reg_ws, ptr[rsp + stack_off_diff_scale]);
1779 }
1780 for (int idx = 0; idx < num_ch_blks; ++idx) {
1781 const int coff = idx * vlen;
1782 const Vmm vsqrtvar_ch = Vmm(idx);
1783 uni_vmovups_maybe_tail(vsqrtvar_ch, var_ptr(coff));
1784 uni_vaddps(vsqrtvar_ch, vsqrtvar_ch, veps);
1785 uni_vsqrtps(vsqrtvar_ch, vsqrtvar_ch);
1786 uni_vdivps(vsqrtvar_ch, vone, vsqrtvar_ch, vtmp);
1787 if (!pd_->use_global_stats()) {
1788 const Vmm vdiff_beta_ch = Vmm(idx + num_ch_blks);
1789 const Vmm vdiff_gamma_ch = Vmm(idx + 2 * num_ch_blks);
1790 uni_vmovups_maybe_tail(vdiff_beta_ch,
1791 vmmword[reg_diff_shift + reg_coff + coff]);
1792 uni_vmovups_maybe_tail(
1793 vdiff_gamma_ch, vmmword[reg_ws + reg_coff + coff]);
1794 uni_vdivps(vdiff_beta_ch, vdiff_beta_ch, vchan_size);
1795 uni_vmulps(vdiff_gamma_ch, vdiff_gamma_ch, vsqrtvar_ch);
1796 uni_vdivps(vdiff_gamma_ch, vdiff_gamma_ch, vchan_size);
1797 }
1798 }
1799 if (!pd_->use_global_stats()) {
1800 mov(reg_ws, ptr[rsp + stack_off_ws_off_copy]);
1801 }
1802
1803 Label spatial;
1804 L(spatial);
1805 {
1806 for (int idx = 0; idx < num_ch_blks; ++idx) {
1807 const int coff = idx * vlen;
1808 const int offt = idx * vlen_spat_data_;
1809 // vdiff_beta and vdiff_gamma are free registers for nspc
1810 const Vmm vdiff_data = vdiff_beta;
1811 const Vmm vdata = vdiff_gamma;
1812 const Vmm vsqrtvar_ch = Vmm(idx);
1813 uni_vmovups_maybe_tail(vmean, mean_ptr(coff));
1814
1815 if (pd_->use_scale())
1816 uni_vmovups_maybe_tail(vgamma, gamma_ptr(coff));
1817
1818 uni_vmovups_spat_data(vdiff_data,
1819 vmmword[reg_diff_dst + reg_soff_nspc + offt]);
1820
1821 if (with_relu) {
1822 if (isa == avx512_core)
1823 bwd_process_relu_avx512_common(vdiff_data, offt);
1824 else
1825 assert(false);
1826 }
1827
1828 if (!pd_->use_global_stats()) {
1829 const Vmm vdiff_beta_ch = Vmm(idx + num_ch_blks);
1830 const Vmm vdiff_gamma_ch = Vmm(idx + 2 * num_ch_blks);
1831 uni_vsubps(vdiff_data, vdiff_data, vdiff_beta_ch);
1832 uni_vmovups_spat_data(
1833 vdata, vmmword[reg_src + reg_soff_nspc + offt]);
1834 uni_vsubps(vdata, vmean, vdata, vtmp);
1835 uni_vmulps(vdata, vdata, vdiff_gamma_ch);
1836 uni_vaddps(vdiff_data, vdiff_data, vdata);
1837 }
1838
1839 uni_vmulps(vdiff_data, vdiff_data, vsqrtvar_ch);
1840
1841 if (pd_->use_scale()) {
1842 uni_vmulps(vdiff_data, vdiff_data, vgamma);
1843 }
1844
1845 uni_vmovups_spat_data(
1846 vmmword[reg_diff_src + reg_soff_nspc + offt],
1847 vdiff_data, stream_store_allowed);
1848 }
1849 add(reg_soff_nspc, spat_step);
1850 sub(reg_ctr, num_spat_pts);
1851 jnz(spatial, T_NEAR);
1852 }
1853 };
1854
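// use non-temporal stores only when the diff_src base is vlen-aligned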
1855 if (stream_store_supported()) {
1856 Label normal_store, end_store;
1857 test(reg_diff_src, vlen - 1);
1858 jnz(normal_store, T_NEAR);
1859 compute(true);
1860 jmp(end_store, T_NEAR);
1861 L(normal_store);
1862 { compute(false); }
1863 L(end_store);
1864 } else {
1865 compute(false); // disabled for bf16 when data fits in cache
1866 }
1867 }
1868
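// nspc backward over channels: process blocks of 4 / 2 / 1 vector widths,
// then restore ("comeback") the data pointers and offsets.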
1869 void backward_diff_channels_nspc() {
1870 xor_(reg_coff, reg_coff);
1871 mov(reg_coff_max_bwd_copy, reg_coff_max);
1872
1873 Label ch_unroll_label[5];
1874 const int max_ch_unroll = 3;
1875
1876 // TODO: Spatial and channel unrolling decisions should be made during
1877 // initialization depending on the problem size
1878 for (int ch_idx = max_ch_unroll; ch_idx > 0; --ch_idx) {
1879 L(ch_unroll_label[ch_idx]);
1880 {
1881 const int ch_blk_size = (1 << (ch_idx - 1)); // 4, 2, 1
1882 cmp(reg_coff_max, vlen * ch_blk_size);
1883 jl(ch_unroll_label[ch_idx - 1], T_NEAR);
1884
1885 backward_diff_channels_nspc_compute(ch_blk_size);
1886
1887 add(reg_diff_dst, vlen_spat_data_ * ch_blk_size);
1888 if (!pd_->use_global_stats())
1889 add(reg_src, vlen_spat_data_ * ch_blk_size);
1890 add(reg_diff_src, vlen_spat_data_ * ch_blk_size);
1891
1892 // advance mean_ptr() and var_ptr()
1893 add(reg_coff, vlen * ch_blk_size);
1894
1895 add(reg_ws, 2 * ch_blk_size);
1896
1897 sub(reg_coff_max, vlen * ch_blk_size);
1898 jmp(ch_unroll_label[ch_idx], T_NEAR);
1899 }
1900 }
1901 L(ch_unroll_label[0]);
1902
1903 // comeback
1904 mov(reg_coff_max, reg_coff_max_bwd_copy);
1905 mov(reg_diff_scale, ptr[rsp + stack_off_diff_scale]);
1906
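// the data pointers above advance by vlen_spat_data_ per block, which is
// half of the reg_coff step for xf16, hence the temporary halving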
1907 if (is_xf16()) shr(reg_coff_max, 1);
1908 sub(reg_diff_dst, reg_coff_max);
1909 if (!pd_->use_global_stats()) sub(reg_src, reg_coff_max);
1910 sub(reg_diff_src, reg_coff_max);
1911 if (is_xf16()) shl(reg_coff_max, 1);
1912
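// the ReLU workspace keeps 1 bit per element (2 bytes per 16-channel
// block), so its comeback offset is reg_coff_max / 32 (shift by 5)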
1913 shr(reg_coff_max, 5);
1914 sub(reg_ws, reg_coff_max);
1915 shl(reg_coff_max, 5);
1916 }
1917
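// Backward pass: zero the reduction buffers, accumulate per-thread partial
// sums over the assigned data, reduce them into diff_gamma / diff_beta,
// then compute diff_src.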
1918 void backward() {
1919 uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
1920 xor_(reg_coff, reg_coff);
1921 Label zero_rbuf, sh_spatial;
1922
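// zero the per-thread reduction buffers (rbuf1 / rbuf2)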
1923 L(zero_rbuf);
1924 {
1925 uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
1926 uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0));
1927 add(reg_coff, isa == sse41 ? vlen / 2 : vlen);
1928 cmp(reg_coff, reg_coff_max);
1929 jne(zero_rbuf);
1930 }
1931
1932 mov(reg_src, ptr[rsp + stack_off_src]);
1933 mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]);
1934 if (with_relu) {
1935 assert(isa == avx2 || isa == avx512_core);
1936 mov(reg_ws, ptr[rsp + stack_off_ws]);
1937 }
1938
1939 xor_(reg_soff, reg_soff);
1940 L(sh_spatial);
1941 {
1942 xor_(reg_coff, reg_coff);
1943 if (isa == sse41) { mov(reg_tmp_off, reg_soff); }
1944 jbp_->is_nspc_ ? backward_sh_channels_nspc()
1945 : backward_sh_channels();
1946 if (isa == sse41) {
1947 mov(reg_soff, reg_tmp_off);
1948 add(reg_diff_dst, vlen / 2);
1949 add(reg_src, vlen / 2);
1950 mov(reg_coff, vlen / 2);
1951 backward_sh_channels();
1952 sub(reg_diff_dst, vlen / 2);
1953 sub(reg_src, vlen / 2);
1954 }
1955 // Process next image
1956 if (jbp_->is_nspc_) {
1957 // Can use static offsets since we come back after the spatial loop
1958 add(reg_src, mb_offt);
1959 add(reg_diff_dst, mb_offt);
1960 add(reg_soff, mb_offt);
1961 add(reg_ws, ws_mb_offt);
1962 } else {
1963 add(reg_soff, reg_mb_stride_Bc);
1964 }
1965 cmp(reg_soff, reg_soff_max);
1966 jl(sh_spatial);
1967 }
1968
1969 if (jbp_->is_nspc_) {
1970 // comeback
1971 mov(reg_src, ptr[rsp + stack_off_src]);
1972 mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]);
1973 }
1974
1975 mov(reg_diff_scale, ptr[rsp + stack_off_diff_scale]);
1976 mov(reg_diff_shift, ptr[rsp + stack_off_diff_shift]);
1977
1978 Label no_sh_reduction;
1979 barrier();
1980 {
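// only the thread with N_ithr == 0 reduces the partial sums of its group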
1981 mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
1982 cmp(reg_tmp, 0);
1983 Label sh_reduction_channels;
1984 jne(no_sh_reduction, T_NEAR);
1985
1986 mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
1987 xor_(reg_coff, reg_coff);
1988 L(sh_reduction_channels);
1989 {
1990 mov(reg_roff, reg_coff);
1991 uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
1992 uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
1993 uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
1994 uni_vaddps(vsqrtvar, vsqrtvar, veps);
1995 uni_vsqrtps(vsqrtvar, vsqrtvar);
1996 uni_vdivps(vsqrtvar, vone, vsqrtvar, vtmp);
1997 mov(reg_ctr, reg_nnthr);
1998 Label sh_reduction_thrs;
1999 L(sh_reduction_thrs);
2000 { // TODO: unroll (?)
2001 uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]);
2002 uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]);
2003 add(reg_roff, reg_coff_max);
2004 sub(reg_ctr, 1);
2005 jnz(sh_reduction_thrs);
2006 }
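// finalize: diff_gamma is additionally scaled by 1 / sqrt(var + eps)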
2007 uni_vmulps(Vmm(0), Vmm(0), vsqrtvar);
2008 uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0));
2009 uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1));
2010 add(reg_coff, isa == sse41 ? vlen / 2 : vlen);
2011 cmp(reg_coff, reg_coff_max);
2012 jne(sh_reduction_channels);
2013 }
2014 }
2015 L(no_sh_reduction);
2016 barrier();
2017
2018 mov(reg_diff_src, ptr[rsp + stack_off_diff_src]);
2019 if (with_relu) {
2020 assert(isa == avx2 || isa == avx512_core);
2021 mov(reg_ws, ptr[rsp + stack_off_ws]);
2022 }
2023
2024 xor_(reg_soff, reg_soff);
2025 Label diff_spatial;
2026 L(diff_spatial);
2027 {
2028 xor_(reg_coff, reg_coff);
2029 // diff_shift is shared with soff_max.
2030 mov(reg_diff_shift, ptr[rsp + stack_off_diff_shift]);
2031 if (isa == sse41) { mov(reg_tmp_off, reg_soff); }
2032 jbp_->is_nspc_ ? backward_diff_channels_nspc()
2033 : backward_diff_channels();
2034 if (isa == sse41) {
2035 mov(reg_soff, reg_tmp_off);
2036 add(reg_diff_dst, vlen / 2);
2037 add(reg_diff_src, vlen / 2);
2038 add(reg_src, vlen / 2);
2039 mov(reg_coff, vlen / 2);
2040 backward_diff_channels();
2041 sub(reg_diff_dst, vlen / 2);
2042 sub(reg_diff_src, vlen / 2);
2043 sub(reg_src, vlen / 2);
2044 }
2045 // Process next image
2046 if (jbp_->is_nspc_) {
2047 // Can use static offsets since we come back after the spatial loop
2048 if (!pd_->use_global_stats()) add(reg_src, mb_offt);
2049 add(reg_diff_dst, mb_offt);
2050 add(reg_diff_src, mb_offt);
2051 add(reg_soff, mb_offt);
2052 add(reg_ws, ws_mb_offt);
2053 } else {
2054 add(reg_soff, reg_mb_stride_Bc);
2055 }
2056
2057 // comeback soff_max. Shared with diff_shift.
2058 mov(reg_soff_max, ptr[rsp + stack_off_soff_max]);
2059 cmp(reg_soff, reg_soff_max);
2060 jl(diff_spatial);
2061 }
2062 if (jbp_->is_nspc_) {
2063 // comeback
2064 if (!pd_->use_global_stats())
2065 mov(reg_src, ptr[rsp + stack_off_src]);
2066 mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]);
2067 mov(reg_diff_src, ptr[rsp + stack_off_diff_src]);
2068 if (with_relu) mov(reg_ws, ptr[rsp + stack_off_ws]);
2069 }
2070 }
2071
2072 jit_bnorm_t(const batch_normalization_pd_t *pd, const jit_bnorm_conf_t *jbp)
2073 : jit_generator(jit_name()), pd_(pd), jbp_(jbp) {
2074 static_assert(isa == sse41 || isa == avx2 || isa == avx512_core,
2075 "unsupported isa");
2076
2077 is_bf16_ = pd_->src_md()->data_type == data_type::bf16;
2078 is_f16_ = pd_->src_md()->data_type == data_type::f16;
2079 is_avx2_ne_xf16_
2080 = isa == avx2 && mayiuse(avx2_vnni_2) && (is_bf16_ || is_f16_);
2081 vlen_spat_data_ = vlen / (1 + is_xf16()); // 32B of xF16 -> 64B of FP32
2082
2083 unroll_blocks = isa == avx512_core && !jbp_->is_spatial_thr_ ? 4 : 1;
2084 unroll_regs = isa == avx512_core && !jbp_->is_spatial_thr_ ? 4 : 1;
2085 }
2086
2087 void generate() override {
2088 preamble();
2089
2090 if (use_bf16_emulation()) {
2091 // init emulation of bfloat16 operations
2092 bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserved_1,
2093 bf16_emu_reserved_2, bf16_emu_reserved_3, reg_bf16_tmp,
2094 bf16_emu_reserved_4, bf16_emu_reserved_4);
2095 bf16_emu_->init_vcvtneps2bf16();
2096 }
2097
2098 if (isa == avx512_core)
2099 prepare_tail_mask_avx512_common();
2100 else if (isa == avx2)
2101 prepare_tail_mask_avx2_common();
2102
2103 compute_static_strides();
2104
2105 prepare_relu();
2106
2107 sub(rsp, stack_size_required);
2108 load_common_params();
2109
2110 if (pd_->is_fwd()) {
2111 if (!pd_->stats_is_src()) { compute_mean_variance(); }
2112 forward();
2113 } else {
2114 backward();
2115 }
2116 add(rsp, stack_size_required);
2117 postamble();
2118 }
2119
2120 void operator()(const call_params_t *p) { jit_generator::operator()(p); }
2121
2122 ~jit_bnorm_t() override { delete bf16_emu_; }
2123};
2124
2125namespace bnorm_impl {
2126
2127template <cpu_isa_t isa>
2128struct driver_t : public c_compatible {
2129 driver_t(const batch_normalization_pd_t *pd, int nthr)
2130 : pd_(pd), jbp_(pd_, nthr, simd_w), ker_(pd_, &jbp_) {}
2131
2132 ~driver_t() = default;
2133
2134 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
2135 const batch_normalization_pd_t *pd, int nthr) {
2136 dim_t C_PADDED = get_c_padded(pd);
2137
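// sbuf: temporary mean/var (forward inference without user-provided stats),
// pbuf: temporary diff_scale/diff_shift, rbuf: per-thread reduction buffers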
2138 auto sbuf_sz = use_tmp_stats(pd) * 2 * C_PADDED;
2139 auto pbuf_sz
2140 = (use_tmp_diff_scale(pd) + use_tmp_diff_shift(pd)) * C_PADDED;
2141 auto rbuf_sz = (pd->is_fwd() ? 1 : 2) * C_PADDED * nthr;
2142
2143 scratchpad.book<acc_data_t>(key_bnorm_tmp_stats, sbuf_sz);
2144 scratchpad.book<acc_data_t>(key_bnorm_tmp_diff_ss, pbuf_sz);
2145 scratchpad.book<acc_data_t>(key_bnorm_reduction, rbuf_sz);
2146
2147 if (dnnl_thr_syncable()) {
2148 auto n_barriers = C_PADDED / simd_w;
2149 scratchpad.book<barrier::ctx_64_t>(key_barrier, n_barriers);
2150 }
2151 }
2152
2153 // given nthr, the problem shape, and the thread partition,
2154 // balance the work among the threads
2155 void thread_balance(int ithr, int nthr, dim_t N, dim_t C_blks, dim_t SP,
2156 int &C_ithr, int C_nthr, dim_t &C_blk_s, dim_t &C_blk_e,
2157 int &N_ithr, int N_nthr, dim_t &N_s, dim_t &N_e, int &S_ithr,
2158 int S_nthr, dim_t &S_s, dim_t &S_e) {
2159 if (ithr < C_nthr * N_nthr * S_nthr) {
2160 utils::nd_iterator_init(
2161 ithr, C_ithr, C_nthr, N_ithr, N_nthr, S_ithr, S_nthr);
2162 balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e);
2163 balance211(N, N_nthr, N_ithr, N_s, N_e);
2164 balance211(SP, S_nthr, S_ithr, S_s, S_e);
2165 } else {
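// threads that do not participate in the partition get empty ranges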
2166 S_ithr = N_ithr = C_ithr = -ithr;
2167 S_s = S_e = N_s = N_e = C_blk_s = C_blk_e = -1;
2168 }
2169 }
2170
2171 void exec(int ithr, int nthr, const void *src, void *diff_src, void *dst,
2172 const void *diff_dst, const acc_data_t *scale,
2173 acc_data_t *diff_scale, const acc_data_t *shift,
2174 acc_data_t *diff_shift, const acc_data_t *mean,
2175 const acc_data_t *var, const uint8_t *ws,
2176 const memory_tracking::grantor_t &scratchpad) {
2177 auto sbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_stats);
2178 auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss);
2179 auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction);
2180 auto barriers = scratchpad.get<barrier::ctx_64_t>(key_barrier);
2181
2182 dim_t N = pd_->MB();
2183 dim_t C = pd_->C();
2184 dim_t C_PADDED = get_c_padded(pd_);
2185 dim_t D = pd_->D();
2186 dim_t H = pd_->H();
2187 dim_t W = pd_->W();
2188 dim_t SP = D * H * W;
2189 dim_t img_size = C_PADDED * SP;
2190 const int vlen_spat_data = ker_.spat_step;
2191
2192 typename jit_bnorm_t<isa>::call_params_t p;
2193
2194 p.eps = pd_->desc()->batch_norm_epsilon;
2195 p.one = 1.0f;
2196 p.spat_size = SP;
2197 p.chan_size = 1.0f * N * p.spat_size;
2198
2199 int C_ithr {0}, N_ithr {0}, S_ithr {0};
2200 dim_t C_blk_s {0}, C_blk_e {0}, N_s {0}, N_e {0}, S_s {0}, S_e {0};
2201
2202 this->thread_balance(ithr, nthr, N, jbp_.C_blks_per_iter_, SP, C_ithr,
2203 jbp_.C_nthr_, C_blk_s, C_blk_e, N_ithr, jbp_.N_nthr_, N_s, N_e,
2204 S_ithr, jbp_.S_nthr_, S_s, S_e);
2205
2206 int SP_N_ithr = N_ithr * jbp_.S_nthr_ + S_ithr;
2207 int SP_N_nthr = jbp_.N_nthr_ * jbp_.S_nthr_;
2208 assert(IMPLICATION(!dnnl_thr_syncable(), SP_N_nthr == 1));
2209
2210 p.N_ithr = SP_N_ithr;
2211 p.N_nthr = SP_N_nthr;
2212
2213 int global_C_blk_s;
2214 int global_barriers_per_iter = jbp_.C_nthr_;
2215
2216 for (int64_t it = 0; it < jbp_.iters_; it++) {
2217 if (it == jbp_.iters_ - 1 && jbp_.iters_ > 1) {
2218 C_blk_s = C_blk_e = N_s = N_e = 0;
2219 this->thread_balance(ithr, nthr, N, jbp_.C_blks_last_iter_, SP,
2220 C_ithr, jbp_.C_nthr_last_iter_, C_blk_s, C_blk_e,
2221 N_ithr, jbp_.N_nthr_last_iter_, N_s, N_e, S_ithr,
2222 jbp_.S_nthr_last_iter_, S_s, S_e);
2223
2224 // Update call parameters for JIT, last iteration
2225 p.N_ithr = N_ithr * jbp_.S_nthr_last_iter_ + S_ithr;
2226 p.N_nthr = jbp_.N_nthr_last_iter_ * jbp_.S_nthr_last_iter_;
2227 }
2228
2229 global_C_blk_s = jbp_.do_blocking_ ? (C_blk_s == -1)
2230 ? -1
2231 : it * jbp_.C_blks_per_iter_ + C_blk_s
2232 : C_blk_s;
2233
2234 int C_blks_thr = C_blk_e - C_blk_s;
2235 int N_thr = N_e - N_s;
2236
2237 if (C_blks_thr == 0 || N_thr == 0) continue;
2238
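// element offsets of this thread's first channel block (coff) and of its
// first data point (soff)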
2239 size_t coff_base = global_C_blk_s * simd_w;
2240 size_t soff_base = jbp_.is_nspc_
2241 ? coff_base + N_s * img_size
2242 : global_C_blk_s * p.spat_size * simd_w + N_s * img_size;
2243 size_t shift_off = use_tmp_diff_scale(pd_) ? pd_->C() : 0;
2244
2245 p.spat_size_loc = S_e - S_s;
2246 p.S_s = S_s * vlen_spat_data;
2247 p.S_tail = (p.spat_size - S_e) * vlen_spat_data;
2248 p.coff_max = C_blks_thr * simd_w;
2249 const auto tmp_mean = use_tmp_stats(pd_) ? sbuf : mean;
2250 if (tmp_mean != nullptr) p.mean = tmp_mean + coff_base;
2251 const auto tmp_var = use_tmp_stats(pd_) ? sbuf + C_PADDED : var;
2252 if (tmp_var != nullptr) p.var = tmp_var + coff_base;
2253 if (scale != nullptr) p.scale = scale + coff_base;
2254 if (shift != nullptr) p.shift = shift + coff_base;
2255 const auto tmp_diff_scale
2256 = use_tmp_diff_scale(pd_) ? pbuf : diff_scale;
2257 if (tmp_diff_scale != nullptr)
2258 p.diff_scale = tmp_diff_scale + coff_base;
2259 const auto tmp_diff_shift
2260 = use_tmp_diff_shift(pd_) ? &pbuf[shift_off] : diff_shift;
2261 if (tmp_diff_shift != nullptr)
2262 p.diff_shift = tmp_diff_shift + coff_base;
2263
2264 p.soff_max = jbp_.dt_size_ * N_thr * img_size;
2265 if (src != nullptr)
2266 p.src = (void *)((char *)src + soff_base * jbp_.dt_size_);
2267 if (dst != nullptr)
2268 p.dst = (void *)((char *)dst + soff_base * jbp_.dt_size_);
2269 if (diff_src != nullptr)
2270 p.diff_src = (void *)((char *)diff_src
2271 + soff_base * jbp_.dt_size_);
2272 if (diff_dst != nullptr)
2273 p.diff_dst = (void *)((char *)diff_dst
2274 + soff_base * jbp_.dt_size_);
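// the workspace (ReLU mask) keeps 1 bit per element, hence the / 8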
2275 if (ws != nullptr) p.ws = ws + soff_base / 8;
2276
2277 p.mb_stride_Bc
2278 = jbp_.dt_size_ * (img_size - p.coff_max * p.spat_size);
2279
2280 // use SP_N_nthr which is the same as p.N_nthr except maybe for
2281 // the last iteration.
2282 p.rbuf1 = rbuf
2283 + ((it * jbp_.C_blks_per_iter_) * SP_N_nthr
2284 + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr)
2285 * simd_w;
2286 // rbuf1 and rbuf2 have to be disjoint
2287 p.rbuf2 = p.rbuf1 + C_PADDED * nthr;
2288 p.is_cblk_tail
2289 = (it * jbp_.C_blks_per_iter_ + C_blk_e) * simd_w > C;
2290
2291 size_t iter_barriers
2292 = jbp_.do_blocking_ ? it * global_barriers_per_iter : 0;
2293 p.barrier = barriers + C_ithr + iter_barriers;
2294 if (p.soff_max != 0 && p.coff_max != 0) ker_(&p);
2295 }
2296 }
2297
2298 void init_barriers(const memory_tracking::grantor_t &scratchpad) {
2299 auto barriers = scratchpad.get<barrier::ctx_64_t>(key_barrier);
2300 if (barriers) {
2301 const int n_barriers = get_c_padded(pd_) / simd_w;
2302 for (int i = 0; i < n_barriers; ++i)
2303 barrier::ctx_init(&barriers[i]);
2304 }
2305 }
2306
2307 status_t create_kernel() { return ker_.create_kernel(); }
2308
2309private:
2310 enum {
2311 simd_w = isa == sse41 ? 8
2312 : cpu_isa_traits<isa>::vlen
2313 / sizeof(acc_data_t) // BF16 will expand to FP32
2314 };
2315
2316 static bool use_tmp_stats(const batch_normalization_pd_t *pd) {
2317 return !pd->stats_is_src()
2318 && pd->desc()->prop_kind == prop_kind::forward_inference;
2319 }
2320
2321 static bool use_tmp_diff_scale(const batch_normalization_pd_t *pd) {
2322 return (!pd->is_fwd() && !pd->use_scale())
2323 || pd->desc()->prop_kind == prop_kind::backward_data;
2324 }
2325
2326 static bool use_tmp_diff_shift(const batch_normalization_pd_t *pd) {
2327 return (!pd->is_fwd() && !pd->use_shift())
2328 || pd->desc()->prop_kind == prop_kind::backward_data;
2329 }
2330
2331 const batch_normalization_pd_t *pd_;
2332 jit_bnorm_conf_t jbp_;
2333 jit_bnorm_t<isa> ker_;
2334};
2335} // namespace bnorm_impl
2336
2337using namespace data_type;
2338using namespace format_tag;
2339using namespace utils;
2340
2341/* fwd */
2342
2343template <cpu_isa_t isa>
2344status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init(engine_t *engine) {
2345 bool ok = is_fwd() && mayiuse(isa)
2346 && !has_zero_dim_memory()
2347 // Algorithm requires barriers for best performance.
2348 // TBB utilizes jit_uni_tbb_batch_normalization implementation.
2349 && dnnl_thr_syncable()
2350 && one_of(src_md()->data_type, f32, bf16, f16)
2351 && src_md()->data_type == dst_md()->data_type
2352 && IMPLICATION(src_md()->data_type == bf16,
2353 is_superset(isa, avx512_core)
2354 || (isa == avx2 && mayiuse(avx2_vnni_2)))
2355 // Note: re-using avx512_core/avx2 implementation for f16.
2356 // This is okay since we do not currently support binary post-ops
2357 // for this primitive.
2358 && IMPLICATION(src_md()->data_type == f16,
2359 (is_superset(isa, avx512_core) && mayiuse(avx512_core_fp16))
2360 || (isa == avx2 && mayiuse(avx2_vnni_2)))
2361 && check_scale_shift_data_type()
2362 && (attr()->has_default_values()
2363 || with_relu_post_op(is_training()))
2364 && set_default_formats_common()
2365 && memory_desc_wrapper(src_md()) == memory_desc_wrapper(dst_md());
2366 if (!ok) return status::unimplemented;
2367
2368 // BN+Add+Relu fusion is not currently implemented
2369 if (fuse_norm_add_relu()) return status::unimplemented;
2370
2371 const memory_desc_wrapper src_d(src_md());
2372 if (isa == avx512_core) {
2373 if (!src_d.matches_one_of_tag(
2374 nCw16c, nChw16c, nCdhw16c, nc, nwc, nhwc, ndhwc))
2375 return status::unimplemented;
2376 } else if (isa == avx2 && one_of(src_md()->data_type, bf16, f16)) {
2377 // no support for training or blocked layouts for avx2_vnni_2
2378 if (is_training() || !src_d.matches_one_of_tag(nc, nwc, nhwc, ndhwc))
2379 return status::unimplemented;
2380 } else if (isa == avx2) {
2381 // full support
2382 if (!src_d.matches_one_of_tag(
2383 nCw8c, nChw8c, nCdhw8c, nc, nwc, nhwc, ndhwc))
2384 return status::unimplemented;
2385 } else {
2386 if (!src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c))
2387 return status::unimplemented;
2388 }
2389
2390 const bool isa_supports_avx2 = is_superset(isa, avx2);
2391 if (is_training() && fuse_norm_relu()) {
2392 if (!isa_supports_avx2) return status::unimplemented;
2393 init_default_ws(1);
2394 }
2395
2396 if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
2397 && !isa_supports_avx2)
2398 return status::unimplemented;
2399
2400 // Only IC % simd_w == 0 is supported for now
2401 const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(acc_data_t);
2402 if (src_d.matches_one_of_tag(nc, nwc, nhwc, ndhwc)
2403 && src_d.padded_dims()[1] % simd_w != 0) {
2404 return status::unimplemented;
2405 }
2406
2407 nthr_ = dnnl_get_max_threads();
2408 auto scratchpad = scratchpad_registry().registrar();
2409 bnorm_impl::driver_t<isa>::init_scratchpad(scratchpad, this, nthr_);
2410
2411 return status::success;
2412}
2413
2414template <cpu_isa_t isa>
2415jit_uni_batch_normalization_fwd_t<isa>::jit_uni_batch_normalization_fwd_t(
2416 const pd_t *apd)
2417 : primitive_t(apd) {}
2418
2419template <cpu_isa_t isa>
2420status_t jit_uni_batch_normalization_fwd_t<isa>::init(engine_t *engine) {
2421 CHECK(safe_ptr_assign(
2422 bnorm_driver_, new bnorm_impl::driver_t<isa>(pd(), pd()->nthr_)));
2423 return bnorm_driver_->create_kernel();
2424}
2425
2426template <cpu_isa_t isa>
2427status_t jit_uni_batch_normalization_fwd_t<isa>::execute(
2428 const exec_ctx_t &ctx) const {
2429
2430 auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
2431 auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE);
2432 auto shift = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SHIFT);
2433
2434 auto mean = pd()->stats_is_src() ? const_cast<acc_data_t *>(
2435 CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN))
2436 : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_MEAN);
2437 auto var = pd()->stats_is_src()
2438 ? const_cast<acc_data_t *>(
2439 CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE))
2440 : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_VARIANCE);
2441 auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
2442 auto ws = CTX_OUT_MEM(uint8_t *, DNNL_ARG_WORKSPACE);
2443
2444 auto scratchpad = ctx.get_scratchpad_grantor();
2445
2446 bnorm_driver_->init_barriers(scratchpad);
2447 const int nthr = pd()->nthr_;
2448
2449 parallel(nthr, [&](const int ithr, const int nthr) {
2450 bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr, scale,
2451 nullptr, shift, nullptr, mean, var, ws, scratchpad);
2452 });
2453
2454 return status::success;
2455}
2456
2457template <cpu_isa_t isa>
2458jit_uni_batch_normalization_fwd_t<isa>::~jit_uni_batch_normalization_fwd_t() {
2459 delete bnorm_driver_;
2460}
2461
2462template <cpu_isa_t isa>
2463status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init(engine_t *engine) {
2464 bool ok = !is_fwd() && mayiuse(isa)
2465 && !has_zero_dim_memory()
2466 // Algorithm requires barriers for best performance.
2467 // TBB utilizes jit_uni_tbb_batch_normalization implementation.
2468 && dnnl_thr_syncable()
2469 && one_of(src_md()->data_type, f32, bf16, f16)
2470 && src_md()->data_type == diff_src_md()->data_type
2471 && diff_src_md()->data_type == diff_dst_md()->data_type
2472 && IMPLICATION(
2473 src_md()->data_type == bf16, is_superset(isa, avx512_core))
2474 // Note: re-using avx512_core implementation for f16. This is okay
2475 // since we do not currently support binary post-ops for this
2476 // primitive.
2477 && IMPLICATION(src_md()->data_type == f16,
2478 is_superset(isa, avx512_core) && mayiuse(avx512_core_fp16))
2479 && check_scale_shift_data_type() && attr()->has_default_values()
2480 && set_default_formats_common()
2481 && memory_desc_wrapper(diff_src_md())
2482 == memory_desc_wrapper(diff_dst_md());
2483 if (!ok) return status::unimplemented;
2484
2485 // BN+Add+Relu fusion is not currently implemented
2486 if (fuse_norm_add_relu()) return status::unimplemented;
2487
2488 const memory_desc_wrapper src_d(src_md());
2489 const memory_desc_wrapper diff_src_d(diff_src_md());
2490
2491 format_tag_t src_tag, diff_src_tag;
2492 if (isa == avx512_core) {
2493 src_tag = src_d.matches_one_of_tag(
2494 nc, nwc, nCw16c, nhwc, nChw16c, ndhwc, nCdhw16c);
2495 diff_src_tag = diff_src_d.matches_one_of_tag(
2496 nc, nwc, nCw16c, nhwc, nChw16c, ndhwc, nCdhw16c);
2497 } else {
2498 src_tag = src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c);
2499 diff_src_tag = diff_src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c);
2500 }
2501 ok = (src_tag != format_tag::undef && diff_src_tag != format_tag::undef
2502 && src_tag == diff_src_tag);
2503 if (!ok) return status::unimplemented;
2504
2505 const bool isa_supports_avx2 = is_superset(isa, avx2);
2506 if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
2507 && !isa_supports_avx2)
2508 return status::unimplemented;
2509
2510 // Only IC % 16 == 0 is supported for now
2511 if (src_d.matches_one_of_tag(nc, nwc, nhwc, ndhwc)
2512 && src_d.padded_dims()[1] % 16 != 0) {
2513 return status::unimplemented;
2514 }
2515
2516 if (fuse_norm_relu()) {
2517 if (!isa_supports_avx2) return status::unimplemented;
2518 init_default_ws(1);
2519 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
2520 }
2521
2522 /* TODO: extra checks required */
2523
2524 nthr_ = dnnl_get_max_threads();
2525 auto scratchpad = scratchpad_registry().registrar();
2526 bnorm_impl::driver_t<isa>::init_scratchpad(scratchpad, this, nthr_);
2527
2528 return status::success;
2529}
2530
2531template <cpu_isa_t isa>
2532jit_uni_batch_normalization_bwd_t<isa>::jit_uni_batch_normalization_bwd_t(
2533 const pd_t *apd)
2534 : primitive_t(apd) {}
2535
2536template <cpu_isa_t isa>
2537status_t jit_uni_batch_normalization_bwd_t<isa>::init(engine_t *engine) {
2538 CHECK(safe_ptr_assign(
2539 bnorm_driver_, new bnorm_impl::driver_t<isa>(pd(), pd()->nthr_)));
2540 return bnorm_driver_->create_kernel();
2541}
2542
2543template <cpu_isa_t isa>
2544status_t jit_uni_batch_normalization_bwd_t<isa>::execute(
2545 const exec_ctx_t &ctx) const {
2546 auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
2547 auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN);
2548 auto var = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE);
2549 auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
2550 auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE);
2551 auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE);
2552
2553 auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC);
2554 auto diff_scale = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SCALE);
2555 auto diff_shift = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT);
2556
2557 auto scratchpad = ctx.get_scratchpad_grantor();
2558
2559 bnorm_driver_->init_barriers(scratchpad);
2560 const int nthr = pd()->nthr_;
2561
2562 parallel(nthr, [&](const int ithr, const int nthr) {
2563 bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst, scale,
2564 diff_scale, nullptr, diff_shift, mean, var, ws, scratchpad);
2565 });
2566
2567 return status::success;
2568}
2569
2570template <cpu_isa_t isa>
2571jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t() {
2572 delete bnorm_driver_;
2573}
2574
2575/* struct instantiation */
2576template struct jit_uni_batch_normalization_fwd_t<sse41>;
2577template struct jit_uni_batch_normalization_bwd_t<sse41>;
2578template struct jit_uni_batch_normalization_fwd_t<avx2>;
2579template struct jit_uni_batch_normalization_bwd_t<avx2>;
2580template struct jit_uni_batch_normalization_fwd_t<avx512_core>;
2581template struct jit_uni_batch_normalization_bwd_t<avx512_core>;
2582
2583} // namespace x64
2584} // namespace cpu
2585} // namespace impl
2586} // namespace dnnl
2587