1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <cassert>
18#include <cmath>
19#include <memory>
20
21#include "common/c_types_map.hpp"
22#include "common/dnnl_thread.hpp"
23#include "common/math_utils.hpp"
24#include "common/memory_tracking.hpp"
25#include "common/nstl.hpp"
26#include "common/type_helpers.hpp"
27#include "common/utils.hpp"
28
29#include "cpu/cpu_batch_normalization_utils.hpp"
30#include "cpu/platform.hpp"
31#include "cpu/x64/jit_generator.hpp"
32
33#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
34#include "cpu/x64/jit_uni_tbb_batch_normalization.hpp"
35
36namespace dnnl {
37namespace impl {
38namespace cpu {
39namespace x64 {
40
41using namespace memory_tracking::names;
42using namespace Xbyak;
43using acc_data_t = float;
44
45constexpr int bits_per_byte = 8;
46
47bool normalize_only(const batch_normalization_pd_t *pd) {
48 return pd->stats_is_src() && pd->is_fwd();
49}
50
51dim_t get_c_padded(const batch_normalization_pd_t *pd) {
52 return pd->src_md()->padded_dims[1];
53}
54
55template <cpu_isa_t isa>
56int get_vlen(jit_memory_tag_kind_t tag_kind) {
57 return isa == sse41 && tag_kind == jit_memory_tag_kind_t::blocked
58 ? 32
59 : cpu_isa_traits<isa>::vlen;
60}
61
62template <cpu_isa_t isa>
63int get_simd_w(jit_memory_tag_kind_t tag_kind) {
64 return get_vlen<isa>(tag_kind) / sizeof(acc_data_t);
65}
66
67template <cpu_isa_t isa>
68bool is_avx2_ne_xf16(const batch_normalization_pd_t *pd) {
69 return isa == avx2 && mayiuse(avx2_vnni_2)
70 && utils::one_of(
71 pd->src_md()->data_type, data_type::bf16, data_type::f16);
72}
73
74template <cpu_isa_t isa>
75std::tuple<dim_t, dim_t, dim_t> get_data_strides(
76 const batch_normalization_pd_t *pd, jit_memory_tag_kind_t tag_kind) {
77 const int simd_w = get_simd_w<isa>(tag_kind);
78 size_t stride_N, stride_S, stride_C;
79
80 if (tag_kind == jit_memory_tag_kind_t::nspc) {
81 stride_C = static_cast<size_t>(simd_w);
82 stride_S = static_cast<size_t>(pd->C());
83 stride_N = static_cast<size_t>(pd->D() * pd->H() * pd->W()) * stride_S;
84 } else {
85 const size_t C_blks = static_cast<size_t>(get_c_padded(pd) / simd_w);
86
87 stride_C = static_cast<size_t>(pd->D() * pd->H() * pd->W() * simd_w);
88 stride_S = static_cast<size_t>(simd_w);
89 stride_N = C_blks * stride_C;
90 }
91
92 return std::make_tuple(stride_N, stride_S, stride_C);
93}
94
95#define PARAM_ADDR(x) (reg_param_ + offsetof(call_params_t, x))
96template <cpu_isa_t isa>
97struct jit_bnorm_process_tail_t {
98 using Vmm = typename cpu_isa_traits<isa>::Vmm;
99
100 jit_bnorm_process_tail_t(const batch_normalization_pd_t *pd,
101 jit_generator *host, Reg64 reg_tmp, Reg64 reg_blk_has_tail,
102 Reg64 reg_C, Vmm vtail_mask, Opmask ktail_mask)
103 : h_(host)
104 , reg_tmp_(reg_tmp)
105 , reg_blk_has_tail_(reg_blk_has_tail)
106 , reg_C_(reg_C)
107 , vtail_mask_(vtail_mask)
108 , ktail_mask_(ktail_mask) {
109 const memory_desc_wrapper data_d(pd->src_md());
110 c_is_padded_ = pd->C() != data_d.padded_dims()[1];
111
112 const int vlen = isa == sse41 ? 32 : cpu_isa_traits<isa>::vlen;
113 tail_ = pd->C() % (int)(vlen / sizeof(float));
114 }
115
116 jit_generator *const h_;
117 const Reg64 reg_tmp_;
118 const Reg64 reg_blk_has_tail_;
119 const Reg64 reg_C_;
120 const Vmm vtail_mask_;
121 const Opmask ktail_mask_;
122 bool c_is_padded_;
123 int tail_;
124
125 void prepare_tail_mask_avx512_common() {
126 if (!c_is_padded_) return;
127
128 const int mask = (1 << tail_) - 1;
129
130 Reg32 regw_tmp = reg_tmp_.cvt32();
131 h_->mov(regw_tmp, mask);
132 h_->kmovw(ktail_mask_, regw_tmp);
133 }
134
135 void prepare_tail_mask_avx2_common() {
136 if (!c_is_padded_) return;
137
138 static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
139 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0,
140 0, 0, 0, 0, 0, 0, 0};
141
142 h_->mov(reg_tmp_, reinterpret_cast<size_t>(&mask[8 - tail_]));
143 h_->vmovups(vtail_mask_, h_->ptr[reg_tmp_]);
144 }
145
146 void prepare_tail() {
147 if (isa == avx512_core)
148 prepare_tail_mask_avx512_common();
149 else if (isa == avx2)
150 prepare_tail_mask_avx2_common();
151 }
152
153 void uni_vmovups_tail_avx2_common(
154 const Operand &dst, const Operand &src, Label &l_ret) {
155 if (dst.isMEM()) {
156 h_->vmaskmovps(dst.getAddress(), vtail_mask_, Vmm(src.getIdx()));
157 } else {
158 h_->vmaskmovps(Vmm(dst.getIdx()), vtail_mask_, src.getAddress());
159 }
160 h_->jmp(l_ret);
161 }
162
163 void uni_vmovups_tail_avx512_common(
164 const Operand &dst, const Operand &src, Label &l_ret) {
165 if (dst.isMEM())
166 h_->uni_vmovups(dst.getAddress() | ktail_mask_ | h_->T_z,
167 Vmm(src.getIdx()));
168 else
169 h_->uni_vmovups(Vmm(dst.getIdx()) | ktail_mask_ | h_->T_z,
170 src.getAddress());
171
172 h_->jmp(l_ret);
173 }
174
175 void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) {
176 Label l_no_mask, l_ret;
177 if (c_is_padded_) {
178 h_->cmp(reg_blk_has_tail_, 0);
179 h_->jz(l_no_mask);
180
181 h_->cmp(reg_C_, 1);
182 h_->jne(l_no_mask);
183 assert(isa == avx512_core || isa == avx2);
184 if (isa == avx512_core)
185 uni_vmovups_tail_avx512_common(dst, src, l_ret);
186 else if (isa == avx2)
187 uni_vmovups_tail_avx2_common(dst, src, l_ret);
188 }
189 h_->L(l_no_mask);
190 if (dst.isMEM())
191 h_->uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
192 else
193 h_->uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
194
195 h_->L(l_ret);
196 }
197};
198
199template <cpu_isa_t isa>
200struct jit_bnorm_process_relu_t {
201 using Vmm = typename cpu_isa_traits<isa>::Vmm;
202
203 jit_bnorm_process_relu_t(const batch_normalization_pd_t *pd,
204 jit_generator *host, Reg64 reg_off_dat, Reg64 reg_tmp,
205 Reg64 reg_ptr_ws, Vmm vzero, Vmm vstore_mask, Opmask kstore_mask,
206 Vmm valpha, Vmm vmask, Reg64 reg_alpha)
207 : h_(host)
208 , reg_off_dat_(reg_off_dat)
209 , reg_tmp_(reg_tmp)
210 , reg_ptr_ws_(reg_ptr_ws)
211 , reg_alpha(reg_alpha)
212 , vzero_(vzero)
213 , vstore_mask_(vstore_mask)
214 , kstore_mask_(kstore_mask)
215 , valpha(valpha)
216 , vmask(vmask)
217 , with_relu_(pd->with_relu_post_op(pd->is_training())
218 || pd->fuse_norm_relu())
219 , with_relu_inf_only_(
220 with_relu_ && !(pd->fuse_norm_relu() && pd->is_training()))
221 , bit_shift_(static_cast<int>(log2(bits_per_byte
222 * types::data_type_size(pd->src_md()->data_type))))
223 , alpha(with_relu_inf_only_ && pd->with_relu_post_op(pd->is_training())
224 ? pd->alpha()
225 : 0.f) {}
226
227 jit_bnorm_process_relu_t(const batch_normalization_pd_t *pd,
228 jit_generator *host, Reg64 reg_off_dat, Reg64 reg_tmp,
229 Reg64 reg_ptr_ws, Vmm vzero, Vmm vstore_mask, Opmask kstore_mask)
230 : jit_bnorm_process_relu_t(pd, host, reg_off_dat, reg_tmp, reg_ptr_ws,
231 vzero, vstore_mask, kstore_mask, Vmm(), Vmm(), Reg64()) {}
232
233 jit_generator *const h_;
234 const Reg64 reg_off_dat_;
235 const Reg64 reg_tmp_;
236 const Reg64 reg_ptr_ws_;
237 const Reg64 reg_alpha;
238 const Vmm vzero_, vstore_mask_;
239 const Opmask kstore_mask_;
240 // used for ReLU computation
241 const Vmm valpha;
242 const Vmm vmask; // used for AVX2 and SSE41
243 Label l_relu_mask_avx2_;
244 const bool with_relu_, with_relu_inf_only_;
245 const int bit_shift_;
246 const float alpha;
247
248 bool with_relu() const { return with_relu_; }
249
250 bool with_relu_inf_only() const { return with_relu_inf_only_; }
251
252 void fwd_prepare_relu() {
253 if (with_relu_) { h_->uni_vpxor(vzero_, vzero_, vzero_); }
254 if (with_relu_inf_only_ && alpha != 0)
255 h_->mov(reg_alpha, float2int(alpha));
256 }
257
258 void bwd_prepare_relu() {
259 if (with_relu_) {
260 h_->uni_vpxor(vzero_, vzero_, vzero_);
261 if (isa == avx2) prepare_l_relu_mask_avx2();
262 }
263 }
264
265 void prepare_l_relu_mask_avx2() {
266 Label l_mask_after;
267 h_->jmp(l_mask_after);
268 h_->align(32);
269 h_->L(l_relu_mask_avx2_); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
270 for (int i = 0; i < 8; ++i)
271 h_->dd(1 << i);
272 h_->L(l_mask_after);
273 }
274
275 void fwd_process_relu(Vmm v, const int off = 0) {
276 if (with_relu_inf_only_) {
277 if (alpha != 0.f)
278 fwd_process_relu_alpha(v);
279 else
280 h_->uni_vmaxps(v, v, vzero_);
281 } else if (with_relu_) {
282 if (isa == avx512_core)
283 fwd_process_relu_avx512_common(v, off);
284 else if (isa == avx2)
285 fwd_process_relu_avx2(v, off);
286 else
287 assert(false);
288 }
289 }
290
291 void bwd_process_relu(Vmm v, const int off = 0) {
292 if (with_relu_) {
293 if (isa == avx512_core)
294 bwd_process_relu_avx512_common(v, off);
295 else if (isa == avx2)
296 bwd_process_relu_avx2(v, off);
297 else
298 assert(false);
299 }
300 }
301
302 void fwd_process_relu_avx2(Vmm vdst, const int off = 0) {
303 Reg64 reg_store_mask = reg_tmp_;
304 h_->shr(reg_off_dat_, bit_shift_);
305 h_->vcmpps(vstore_mask_, vzero_, vdst, jit_generator::_cmp_lt_os);
306 h_->vmovmskps(reg_store_mask, vstore_mask_);
307 h_->mov(h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off],
308 reg_store_mask.cvt8());
309 h_->vblendvps(vdst, vzero_, vdst, vstore_mask_);
310 h_->shl(reg_off_dat_, bit_shift_);
311 }
312
313 void fwd_process_relu_avx512_common(Vmm vdst, const int off = 0) {
314 h_->shr(reg_off_dat_, bit_shift_);
315 h_->vcmpps(kstore_mask_, vzero_, vdst, jit_generator::_cmp_lt_os);
316 h_->kmovw(h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off], kstore_mask_);
317 h_->vblendmps(vdst | kstore_mask_, vzero_, vdst);
318 h_->shl(reg_off_dat_, bit_shift_);
319 }
320
321 void fwd_process_relu_alpha(Vmm vmm_dst) {
322 if (isa == avx512_core)
323 fwd_process_relu_alpha_avx512_common(vmm_dst);
324 else {
325 assert(utils::one_of(isa, avx2, sse41));
326 fwd_process_relu_alpha_avx2(vmm_dst);
327 }
328 }
329
330 void fwd_process_relu_alpha_avx512_common(Vmm vmm_dst) {
331 const Xmm xmm_aux = Xmm(valpha.getIdx());
332 h_->vmovq(xmm_aux, reg_alpha);
333 h_->vbroadcastss(valpha, xmm_aux);
334 h_->vcmpps(kstore_mask_, vzero_, vmm_dst, h_->_cmp_lt_os);
335 h_->vmulps(valpha, vmm_dst, valpha);
336 h_->vblendmps(vmm_dst | kstore_mask_, valpha, vmm_dst);
337 }
338
339 void fwd_process_relu_alpha_avx2(Vmm vmm_dst) {
340 const Xmm xmm_aux = Xmm(valpha.getIdx());
341 h_->uni_vpxor(vmask, vmask, vmask);
342 h_->uni_vmovq(xmm_aux, reg_alpha);
343 h_->uni_vbroadcastss(valpha, xmm_aux);
344 h_->uni_vcmpps(vmask, vmm_dst, vzero_, h_->_cmp_lt_os);
345 h_->uni_vmulps(valpha, valpha, vmm_dst);
346 h_->uni_vblendvps(
347 vmm_dst, vmm_dst, valpha, vmask); // swaped aux and dst
348 }
349
350 void bwd_process_relu_avx2(Vmm vdiff_dst, const int off = 0) {
351 h_->shr(reg_off_dat_, bit_shift_);
352 h_->vpbroadcastb(
353 vstore_mask_, h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off]);
354 h_->vpand(vstore_mask_, vstore_mask_,
355 h_->ptr[Xbyak::util::rip + l_relu_mask_avx2_]);
356 h_->vpcmpeqd(vstore_mask_, vstore_mask_,
357 h_->ptr[Xbyak::util::rip + l_relu_mask_avx2_]);
358 h_->vblendvps(vdiff_dst, vzero_, vdiff_dst, vstore_mask_);
359 h_->shl(reg_off_dat_, bit_shift_);
360 }
361
362 void bwd_process_relu_avx512_common(Vmm vdiff_dst, const int off = 0) {
363 h_->shr(reg_off_dat_, bit_shift_);
364 h_->kmovw(kstore_mask_, h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off]);
365 h_->vmovups(vdiff_dst | kstore_mask_ | h_->T_z, vdiff_dst);
366 h_->shl(reg_off_dat_, bit_shift_);
367 }
368};
369
370template <cpu_isa_t isa>
371struct helper_vmovups_data_t {
372 using Vmm = typename cpu_isa_traits<isa>::Vmm;
373
374 helper_vmovups_data_t(const batch_normalization_pd_t *pd,
375 jit_generator *host, Zmm zmm_reserved_1, Zmm zmm_reserved_2,
376 Zmm zmm_reserved_3, Zmm zmm_reserved_4, Reg64 reg_tmp)
377 : h_(host), bf16_emu_(nullptr) {
378 is_bf16_ = pd->src_md()->data_type == data_type::bf16;
379 is_f16_ = pd->src_md()->data_type == data_type::f16;
380 if (is_bf16_ && isa == avx512_core && !mayiuse(avx512_core_bf16)) {
381 bf16_emu_ = utils::make_unique<bf16_emulation_t>(h_, zmm_reserved_1,
382 zmm_reserved_2, zmm_reserved_3, reg_tmp, zmm_reserved_4,
383 zmm_reserved_4);
384 }
385 }
386
387 jit_generator *const h_;
388 std::unique_ptr<bf16_emulation_t> bf16_emu_;
389 bool is_bf16_;
390 bool is_f16_;
391
392 void merge_interleaved_to_plain(const Vmm &vmm_even, const Vmm &vmm_odd,
393 const Vmm &vmm_aux0) const {
394 Ymm ymm_even = Ymm(vmm_even.getIdx());
395 Ymm ymm_odd = Ymm(vmm_odd.getIdx());
396 Ymm ymm_aux0 = Ymm(vmm_aux0.getIdx());
397 Ymm ymm_aux1 = ymm_odd;
398
399 h_->vpunpckldq(ymm_aux0, ymm_even, ymm_odd);
400 h_->vpunpckhdq(ymm_aux1, ymm_even, ymm_odd);
401 h_->vperm2i128(ymm_even, ymm_aux0, ymm_aux1, 0x20);
402 h_->vperm2i128(ymm_odd, ymm_aux0, ymm_aux1, 0x31);
403 }
404
405 void operator()(const Vmm &vmm_even, const Vmm &vmm_odd,
406 const Address &addr) const {
407 // load two simd_w data from addr into two registers
408 if (is_bf16_) {
409 // convert bf16 input to f32
410 h_->vcvtneebf162ps(vmm_even, addr);
411 h_->vcvtneobf162ps(vmm_odd, addr);
412 } else if (is_f16_) {
413 h_->vcvtneeph2ps(vmm_even, addr);
414 h_->vcvtneoph2ps(vmm_odd, addr);
415 } else
416 assert(!"unsupported data type");
417 }
418
419 void operator()(const Operand &dst, const Operand &src) const {
420 if (dst.isMEM()) {
421 if (is_bf16_) {
422 constexpr bool isAvx2 = isa == avx2;
423 const typename std::conditional<isAvx2, Xmm, Ymm>::type
424 dst_reg {src.getIdx()};
425 const typename std::conditional<isAvx2, Ymm, Zmm>::type
426 src_reg {src.getIdx()};
427
428 // convert f32 output to bf16
429 if (!bf16_emu_)
430 h_->vcvtneps2bf16(dst_reg, src_reg,
431 mayiuse(avx512_core) ? Xbyak::EvexEncoding
432 : Xbyak::VexEncoding);
433 else
434 bf16_emu_->vcvtneps2bf16(dst_reg, src_reg);
435
436 h_->uni_vmovups(dst.getAddress(), dst_reg);
437 } else if (is_f16_) {
438 auto src_reg = Vmm(src.getIdx());
439 h_->vcvtps2ph(dst.getAddress(), src_reg, h_->_op_mxcsr);
440 } else {
441 h_->uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
442 }
443 } else {
444 if (is_bf16_) {
445 // convert bf16 input to f32
446 h_->vpmovzxwd(Vmm(dst.getIdx()), src.getAddress());
447 h_->vpslld(Vmm(dst.getIdx()), Vmm(dst.getIdx()), 0x10);
448 } else if (is_f16_) {
449 if (mayiuse(avx512_core_fp16))
450 h_->vcvtph2psx(Vmm(dst.getIdx()), src.getAddress());
451 else
452 h_->vcvtph2ps(Vmm(dst.getIdx()), src.getAddress());
453 } else {
454 h_->uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
455 }
456 }
457 }
458
459private:
460 DNNL_DISALLOW_COPY_AND_ASSIGN(helper_vmovups_data_t);
461};
462
463template <cpu_isa_t isa>
464struct jit_bnorm_fwd_statistics_t : public jit_generator {
465 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_fwd_statistics_t)
466 using Vmm = typename cpu_isa_traits<isa>::Vmm;
467
468 const AddressFrame &vmmword
469 = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
470
471 struct call_params_t {
472 size_t N, C, S;
473 const void *src;
474 const acc_data_t *mean;
475 const acc_data_t *var;
476 size_t blk_has_tail;
477 size_t do_normalise;
478 };
479
480 const Reg64 reg_param_ = abi_param1;
481 const Reg64 reg_tmp_ = abi_not_param1;
482 const Reg64 reg_N_ = rsi;
483 const Reg64 reg_S_ = rax;
484 const Reg64 reg_C_ = rdx;
485 const Reg64 reg_off_c_ = rbx;
486 const Reg64 reg_blk_has_tail_ = rbp;
487
488 const Reg64 reg_off_dat_ = r8;
489 const Reg64 reg_off_dat_save_ = r9;
490 const Reg64 reg_ptr_mean_ = r10;
491 const Reg64 reg_ptr_var_ = r11;
492 const Reg64 reg_ptr_src_ = r12;
493 const Reg64 reg_do_normalise_ = r13;
494 const Reg64 reg_ptr_stat_ = r14;
495
496 const Vmm v_ = Vmm(0);
497 const Vmm vtmp_ = Vmm(1);
498 const Vmm vtail_mask_ = Vmm(2);
499 const Vmm vNS_ = Vmm(3);
500 const Vmm vzero_ = Vmm(4);
501 const Vmm vsrc_aux = Vmm(2); //use for xf16 nspc on AVX2
502 // When variance is computed then two vmms(one for variance and
503 // one for mean) are needed to unroll one c block at any moment,
504 // therefore the number of registers which are used to unrolling
505 // must to be divisible by two.
506 static constexpr int min_idx_to_unroll_ = 4;
507 static constexpr int max_idx_to_unroll_ = isa == avx512_core ? 28 : 16;
508 static constexpr int number_of_vmms_to_unrolling_variables_
509 = max_idx_to_unroll_ - min_idx_to_unroll_;
510 static_assert(number_of_vmms_to_unrolling_variables_ % 2 == 0
511 && number_of_vmms_to_unrolling_variables_ != 0,
512 "Number of register to unrolling must to be divisible by 2.");
513
514 const Opmask ktail_mask_ = k2;
515
516 const batch_normalization_pd_t *pd_;
517 const jit_memory_tag_kind_t tag_kind_;
518 const int vlen;
519 const int simd_w;
520 const bool is_avx2_ne_xf16_;
521 jit_bnorm_process_tail_t<isa> jit_tail_;
522 helper_vmovups_data_t<isa> helper_vmovups_;
523 int stride_N_, stride_S_, stride_C_;
524 size_t data_type_size_, acc_type_size_;
525
526 void load_common_params() {
527#define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
528 mov(reg_ptr_src_, PARAM_PTR(src));
529 mov(reg_ptr_mean_, PARAM_PTR(mean));
530 mov(reg_ptr_var_, PARAM_PTR(var));
531#undef PARAM_PTR
532 mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);
533 mov(reg_do_normalise_, dword[PARAM_ADDR(do_normalise)]);
534 }
535
536 void zeroise() {
537 Label label_zeroise;
538 xor_(reg_off_c_, reg_off_c_);
539 uni_vpxor(vzero_, vzero_, vzero_);
540 mov(reg_C_, dword[PARAM_ADDR(C)]);
541 L(label_zeroise);
542 {
543 jit_tail_.uni_vmovups_maybe_tail(
544 vmmword[reg_ptr_stat_ + reg_off_c_], vzero_);
545 if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
546 jit_tail_.uni_vmovups_maybe_tail(
547 vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2], vzero_);
548 }
549 add(reg_off_c_, simd_w * acc_type_size_);
550 dec(reg_C_);
551 jnz(label_zeroise);
552 }
553 }
554
555 void load_stat(bool compute_mean, const int c_blks_to_unroll = 1) {
556 int start_idx = min_idx_to_unroll_;
557 int end_idx = c_blks_to_unroll + min_idx_to_unroll_;
558 const int step = simd_w * acc_type_size_;
559
560 // load mean or variance
561 for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) {
562 const Vmm vstat = Vmm(idx);
563 jit_tail_.uni_vmovups_maybe_tail(
564 vstat, vmmword[reg_ptr_stat_ + reg_off_c_ + off]);
565 }
566
567 // if variance is counted then mean also is needed
568 if (!compute_mean) {
569 start_idx = min_idx_to_unroll_ + c_blks_to_unroll;
570 end_idx = min_idx_to_unroll_ + 2 * c_blks_to_unroll;
571
572 for (int idx = start_idx, off = 0; idx < end_idx;
573 idx++, off += step) {
574 const Vmm vmean = Vmm(idx);
575 jit_tail_.uni_vmovups_maybe_tail(
576 vmean, vmmword[reg_ptr_mean_ + reg_off_c_ + off]);
577 }
578 }
579 }
580
581 void compute_stat(bool compute_mean, const int c_blks_to_unroll = 1) {
582 const int start_idx = min_idx_to_unroll_;
583 const int end_idx = c_blks_to_unroll + min_idx_to_unroll_;
584 const int step = simd_w * data_type_size_;
585
586 for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) {
587 const Vmm vstat = Vmm(idx);
588
589 helper_vmovups_(v_, vmmword[reg_ptr_src_ + reg_off_dat_ + off]);
590
591 if (compute_mean) {
592 uni_vaddps(vstat, vstat, v_);
593 } else {
594 const Vmm vmean = Vmm(idx + c_blks_to_unroll);
595
596 // var += (src - mean)^2
597 uni_vsubps(vtmp_, v_, vmean, vtmp_);
598 uni_vfmadd231ps(vstat, vtmp_, vtmp_);
599 }
600 }
601 }
602
603 void compute_stat_avx2_ne_xf16(
604 bool compute_mean, const int c_blks_to_unroll = 1) {
605 const int start_idx = min_idx_to_unroll_;
606 const int end_idx = c_blks_to_unroll + min_idx_to_unroll_;
607 const int step = simd_w * data_type_size_;
608
609 for (int idx = start_idx, off = 0; idx < end_idx;
610 idx += 2, off += 2 * step) {
611 const bool is_c_blks_tail = (end_idx - idx) < 2;
612 const Vmm vsrc_even = v_;
613 const Vmm vsrc_odd = vsrc_aux;
614 if (is_c_blks_tail)
615 helper_vmovups_(
616 vsrc_even, vmmword[reg_ptr_src_ + reg_off_dat_ + off]);
617 else
618 helper_vmovups_(vsrc_even, vsrc_odd,
619 vmmword[reg_ptr_src_ + reg_off_dat_ + off]);
620 for (int i_odd = 0; i_odd < 2 && idx + i_odd < end_idx; ++i_odd) {
621 const Vmm vstat = Vmm(idx + i_odd);
622 const Vmm vsrc = i_odd ? vsrc_odd : vsrc_even;
623 if (compute_mean) {
624 uni_vaddps(vstat, vstat, vsrc);
625 } else {
626 const Vmm vmean = Vmm(idx + i_odd + c_blks_to_unroll);
627 uni_vsubps(vtmp_, vsrc, vmean, vtmp_);
628 uni_vfmadd231ps(vstat, vtmp_, vtmp_);
629 }
630 }
631 }
632 }
633
634 void store_stat(const int c_blks_to_unroll = 1) {
635 const int start_idx = min_idx_to_unroll_;
636 const int end_idx = c_blks_to_unroll + min_idx_to_unroll_;
637 const int step = simd_w * acc_type_size_;
638
639 for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) {
640 const Vmm vstat = Vmm(idx);
641
642 jit_tail_.uni_vmovups_maybe_tail(
643 vmmword[reg_ptr_stat_ + reg_off_c_ + off], vstat);
644 }
645 }
646
647 void compute_blocked(bool compute_mean) {
648 Label label_C, label_S;
649 mov(reg_C_, dword[PARAM_ADDR(C)]);
650 L(label_C);
651 {
652 mov(reg_off_dat_, reg_off_dat_save_);
653
654 load_stat(compute_mean);
655
656 mov(reg_S_, dword[PARAM_ADDR(S)]);
657 L(label_S);
658 {
659 compute_stat(compute_mean);
660
661 add(reg_off_dat_, stride_S_ * data_type_size_);
662
663 dec(reg_S_);
664 jnz(label_S);
665 }
666
667 store_stat();
668
669 add(reg_off_dat_save_, stride_C_ * data_type_size_);
670 add(reg_off_c_, simd_w * acc_type_size_);
671
672 dec(reg_C_);
673 jnz(label_C);
674 }
675 }
676
677 void compute_nspc(bool compute_mean) {
678 mov(reg_C_, dword[PARAM_ADDR(C)]);
679
680 // When a variance is computed, two values are unrolled: mean and variance,
681 // so number_of_vmms_to_unrolling_variables_ is divided by 2.
682 const int max_of_unrolled_c_blks = compute_mean
683 ? number_of_vmms_to_unrolling_variables_
684 : number_of_vmms_to_unrolling_variables_ / 2;
685 std::vector<Label> c_unroll_label(max_of_unrolled_c_blks + 1);
686
687 for (int c_blks_to_unroll = max_of_unrolled_c_blks;
688 c_blks_to_unroll > 0; --c_blks_to_unroll) {
689 L(c_unroll_label[c_blks_to_unroll]);
690 {
691 cmp(reg_C_, c_blks_to_unroll);
692 jl(c_unroll_label[c_blks_to_unroll - 1], T_NEAR);
693
694 mov(reg_off_dat_, reg_off_dat_save_);
695
696 load_stat(compute_mean, c_blks_to_unroll);
697
698 Label label_S;
699 mov(reg_S_, dword[PARAM_ADDR(S)]);
700 L(label_S);
701 {
702 is_avx2_ne_xf16_
703 ? compute_stat_avx2_ne_xf16(
704 compute_mean, c_blks_to_unroll)
705 : compute_stat(compute_mean, c_blks_to_unroll);
706
707 add(reg_off_dat_, stride_S_ * data_type_size_);
708
709 dec(reg_S_);
710 jnz(label_S);
711 }
712
713 store_stat(c_blks_to_unroll);
714
715 add(reg_off_c_, c_blks_to_unroll * simd_w * acc_type_size_);
716 add(reg_off_dat_save_,
717 c_blks_to_unroll * stride_C_ * data_type_size_);
718
719 sub(reg_C_, c_blks_to_unroll);
720 jmp(c_unroll_label[c_blks_to_unroll], T_NEAR);
721 }
722 }
723 L(c_unroll_label[0]);
724 }
725
726 void compute(bool compute_mean) {
727 Label label_N;
728 mov(reg_N_, dword[PARAM_ADDR(N)]);
729 L(label_N);
730 {
731 xor_(reg_off_dat_save_, reg_off_dat_save_);
732 xor_(reg_off_c_, reg_off_c_);
733
734 tag_kind_ == jit_memory_tag_kind_t::nspc
735 ? compute_nspc(compute_mean)
736 : compute_blocked(compute_mean);
737
738 if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
739 xor_(reg_off_dat_save_, reg_off_dat_save_);
740 xor_(reg_off_c_, reg_off_c_);
741 add(reg_off_dat_save_, vlen / 2);
742 add(reg_off_c_, vlen / 2);
743
744 compute_blocked(compute_mean);
745 }
746
747 add(reg_ptr_src_, stride_N_ * data_type_size_);
748 dec(reg_N_);
749 jnz(label_N);
750 }
751 }
752
753 void normalize() {
754 Label label_ret, label_normalise;
755 cmp(reg_do_normalise_, 0);
756 jz(label_ret);
757
758 const int S = pd_->D() * pd_->H() * pd_->W();
759 mov(reg_tmp_, float2int(pd_->MB() * S));
760 Xmm xtmp = Xmm(vtmp_.getIdx());
761 uni_vmovq(xtmp, reg_tmp_);
762 uni_vbroadcastss(vNS_, xtmp);
763
764 xor_(reg_off_c_, reg_off_c_);
765 mov(reg_C_, dword[PARAM_ADDR(C)]);
766 L(label_normalise);
767 {
768 jit_tail_.uni_vmovups_maybe_tail(
769 v_, vmmword[reg_ptr_stat_ + reg_off_c_]);
770 uni_vdivps(v_, v_, vNS_);
771 jit_tail_.uni_vmovups_maybe_tail(
772 vmmword[reg_ptr_stat_ + reg_off_c_], v_);
773
774 if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
775 jit_tail_.uni_vmovups_maybe_tail(
776 v_, vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2]);
777 uni_vdivps(v_, v_, vNS_);
778 jit_tail_.uni_vmovups_maybe_tail(
779 vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2], v_);
780 }
781
782 add(reg_off_c_, simd_w * acc_type_size_);
783 dec(reg_C_);
784 jnz(label_normalise);
785 }
786
787 L(label_ret);
788 }
789
790 jit_bnorm_fwd_statistics_t(const batch_normalization_pd_t *pd,
791 const jit_memory_tag_kind_t tag_kind)
792 : jit_generator(jit_name())
793 , pd_(pd)
794 , tag_kind_(tag_kind)
795 , vlen(get_vlen<isa>(tag_kind))
796 , simd_w(get_simd_w<isa>(tag_kind))
797 , is_avx2_ne_xf16_(is_avx2_ne_xf16<isa>(pd))
798 , jit_tail_(pd, this, reg_tmp_, reg_blk_has_tail_, reg_C_, vtail_mask_,
799 ktail_mask_)
800 , helper_vmovups_(pd, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
801 static_assert(utils::one_of(isa, sse41, avx2, avx512_core),
802 "unsupported isa");
803
804 std::tie(stride_N_, stride_S_, stride_C_)
805 = get_data_strides<isa>(pd_, tag_kind);
806
807 data_type_size_ = types::data_type_size(pd->src_md()->data_type);
808 acc_type_size_ = sizeof(acc_data_t);
809 }
810};
811
812template <cpu_isa_t isa>
813struct jit_bnorm_fwd_mean_t : jit_bnorm_fwd_statistics_t<isa> {
814 using call_params_t =
815 typename jit_bnorm_fwd_statistics_t<isa>::call_params_t;
816
817 jit_bnorm_fwd_mean_t(const batch_normalization_pd_t *pd,
818 const jit_memory_tag_kind_t tag_kind)
819 : jit_bnorm_fwd_statistics_t<isa>(pd, tag_kind) {}
820
821 void generate() override {
822 this->preamble();
823 this->load_common_params();
824 this->mov(this->reg_ptr_stat_, this->reg_ptr_mean_);
825 this->jit_tail_.prepare_tail();
826 this->zeroise();
827 this->compute(true);
828 this->normalize();
829 this->postamble();
830 }
831};
832
833template <cpu_isa_t isa>
834struct jit_bnorm_fwd_var_t : jit_bnorm_fwd_statistics_t<isa> {
835 using call_params_t =
836 typename jit_bnorm_fwd_statistics_t<isa>::call_params_t;
837
838 jit_bnorm_fwd_var_t(const batch_normalization_pd_t *pd,
839 const jit_memory_tag_kind_t tag_kind)
840 : jit_bnorm_fwd_statistics_t<isa>(pd, tag_kind) {}
841
842 void generate() override {
843 this->preamble();
844 this->load_common_params();
845 this->mov(this->reg_ptr_stat_, this->reg_ptr_var_);
846 this->jit_tail_.prepare_tail();
847 this->zeroise();
848 this->compute(false);
849 this->normalize();
850 this->postamble();
851 }
852};
853
854template <cpu_isa_t isa>
855struct jit_bnorm_fwd_t : public jit_generator {
856 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_fwd_t)
857 using Vmm = typename cpu_isa_traits<isa>::Vmm;
858
859 const AddressFrame &vmmword
860 = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
861
862 struct call_params_t {
863 size_t N, C, S;
864 const void *src;
865 void *dst;
866 const uint8_t *ws;
867 const acc_data_t *mean, *var;
868 const acc_data_t *scale, *shift;
869 size_t blk_has_tail;
870 };
871
872 const Reg64 reg_param_ = abi_param1;
873 const Reg64 reg_tmp_ = abi_not_param1;
874 const Reg64 reg_N_ = rsi;
875 const Reg64 reg_S_ = rax;
876 const Reg64 reg_C_ = rdx;
877 const Reg64 reg_off_c_ = rbx;
878 const Reg64 reg_blk_has_tail_ = rbp;
879
880 const Reg64 reg_off_dat_ = r8;
881 const Reg64 reg_off_dat_save_ = r9;
882 const Reg64 reg_ptr_ws_ = r10;
883 const Reg64 reg_ptr_scale_ = r11;
884 const Reg64 reg_ptr_shift_ = reg_N_;
885 const Reg64 reg_ptr_var_ = r12;
886 const Reg64 reg_ptr_mean_ = r13;
887 const Reg64 reg_ptr_dst_ = r14;
888 const Reg64 reg_ptr_src_ = r15;
889 const Reg64 reg_alpha_ = reg_ptr_ws_;
890
891 const Vmm vmask = Vmm(0); // required for avx2 and sse41 ReLU computation
892 const Vmm vone_ = Vmm(1);
893 const Vmm vmean_ = Vmm(2);
894 const Vmm vvar_ = Vmm(3);
895 const Vmm vsqrtvar_ = Vmm(4);
896 const Vmm vgamma_ = Vmm(5);
897 const Vmm vbeta_ = Vmm(6);
898 const Vmm veps_ = Vmm(7);
899 const Vmm vtmp_ = Vmm(8);
900 const Vmm v_ = Vmm(9);
901 const Vmm vzero_ = Vmm(10);
902 const Vmm vtail_mask_ = Vmm(11);
903 const Vmm valpha = Vmm(12);
904 const Vmm vsrc_aux = Vmm(13);
905 const Vmm vstore_mask_ = vtmp_;
906 const Vmm vmean_even_ = vmean_;
907 const Vmm vmean_odd_ = Vmm(14);
908 const Vmm vsqrtvar_even_ = vsqrtvar_;
909 const Vmm vsqrtvar_odd_ = Vmm(15);
910 const Vmm vvar_even_ = vvar_;
911 const Vmm vvar_odd_ = vsrc_aux;
912
913 const Opmask kstore_mask_ = k1;
914 const Opmask ktail_mask_ = k2;
915
916 const batch_normalization_pd_t *pd_;
917 const jit_memory_tag_kind_t tag_kind_;
918 const int vlen;
919 const int simd_w;
920 const bool is_avx2_ne_xf16_;
921 jit_bnorm_process_tail_t<isa> jit_tail_;
922 jit_bnorm_process_relu_t<isa> jit_relu_;
923 helper_vmovups_data_t<isa> helper_vmovups_;
924 int stride_N_, stride_S_, stride_C_;
925 size_t data_type_size_, acc_type_size_;
926
927 enum {
928 stack_off_N = 0,
929 stack_off_shift = 8,
930 stack_size_required = 16,
931 };
932
933 void load_common_params() {
934#define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
935 mov(reg_ptr_src_, PARAM_PTR(src));
936 mov(reg_ptr_dst_, PARAM_PTR(dst));
937 mov(reg_ptr_mean_, PARAM_PTR(mean));
938 mov(reg_ptr_var_, PARAM_PTR(var));
939 mov(reg_ptr_scale_, PARAM_PTR(scale));
940 if (jit_relu_.with_relu_ && !jit_relu_.with_relu_inf_only_)
941 mov(reg_ptr_ws_, PARAM_PTR(ws));
942
943 Xmm x = Xmm(v_.getIdx());
944
945 mov(reg_tmp_, float2int(pd_->desc()->batch_norm_epsilon));
946 uni_vmovq(x, reg_tmp_);
947 uni_vbroadcastss(veps_, x);
948
949 mov(reg_tmp_, float2int(1.f));
950 uni_vmovq(x, reg_tmp_);
951 uni_vbroadcastss(vone_, x);
952
953 mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);
954
955 mov(reg_tmp_, PARAM_PTR(shift));
956 mov(ptr[rsp + stack_off_shift], reg_tmp_);
957 mov(reg_tmp_, PARAM_PTR(N));
958 mov(ptr[rsp + stack_off_N], reg_tmp_);
959#undef PARAM_PTR
960 }
961
962 void load_c_specifics(
963 const bool has_load_mean_sqrtvar, const int offt = 0) {
964 if (!has_load_mean_sqrtvar) {
965 jit_tail_.uni_vmovups_maybe_tail(
966 vmean_, vmmword[reg_ptr_mean_ + reg_off_c_ + offt]);
967 jit_tail_.uni_vmovups_maybe_tail(
968 vvar_, vmmword[reg_ptr_var_ + reg_off_c_ + offt]);
969
970 uni_vmovups(vsqrtvar_, vvar_);
971 uni_vaddps(vsqrtvar_, vsqrtvar_, veps_);
972 uni_vsqrtps(vsqrtvar_, vsqrtvar_);
973 if (isa == sse41) {
974 movups(vtmp_, vone_);
975 divps(vtmp_, vsqrtvar_);
976 movups(vsqrtvar_, vtmp_);
977 } else
978 vdivps(vsqrtvar_, vone_, vsqrtvar_);
979 }
980
981 if (pd_->use_scale())
982 jit_tail_.uni_vmovups_maybe_tail(
983 vgamma_, vmmword[reg_ptr_scale_ + reg_off_c_ + offt]);
984 if (pd_->use_shift())
985 jit_tail_.uni_vmovups_maybe_tail(
986 vbeta_, vmmword[reg_ptr_shift_ + reg_off_c_ + offt]);
987 }
988
989 void compute_bnorm(const Vmm &v, const Vmm &vmean, const Vmm &vsqrtvar,
990 bool stream_store_allowed, bool has_load_src, const int offt = 0) {
991 if (!has_load_src)
992 helper_vmovups_(v, vmmword[reg_ptr_src_ + reg_off_dat_ + offt]);
993 uni_vsubps(v, v, vmean);
994 uni_vmulps(v, v, vsqrtvar);
995
996 if (pd_->use_scale() && pd_->use_shift())
997 uni_vfmadd213ps(v, vgamma_, vbeta_);
998 else if (pd_->use_scale())
999 uni_vmulps(v, v, vgamma_);
1000 else if (pd_->use_shift())
1001 uni_vaddps(v, v, vbeta_);
1002
1003 jit_relu_.fwd_process_relu(v);
1004
1005 if (stream_store_allowed) {
1006 uni_vmovntps(vmmword[reg_ptr_dst_ + reg_off_dat_ + offt], v);
1007 } else {
1008 helper_vmovups_(vmmword[reg_ptr_dst_ + reg_off_dat_ + offt], v);
1009 }
1010 }
1011
1012 void load_two_c_mean_sqrtvar() {
1013 const int offt = simd_w * acc_type_size_;
1014 jit_tail_.uni_vmovups_maybe_tail(
1015 vmean_even_, vmmword[reg_ptr_mean_ + reg_off_c_]);
1016 jit_tail_.uni_vmovups_maybe_tail(
1017 vmean_odd_, vmmword[reg_ptr_mean_ + reg_off_c_ + offt]);
1018 jit_tail_.uni_vmovups_maybe_tail(
1019 vvar_even_, vmmword[reg_ptr_var_ + reg_off_c_]);
1020 jit_tail_.uni_vmovups_maybe_tail(
1021 vvar_odd_, vmmword[reg_ptr_var_ + reg_off_c_ + offt]);
1022
1023 // merge mean and variance in interleave to plain layout when needed
1024 if (!pd_->stats_is_src()) {
1025 helper_vmovups_.merge_interleaved_to_plain(
1026 vmean_even_, vmean_odd_, vtmp_);
1027 helper_vmovups_.merge_interleaved_to_plain(
1028 vvar_even_, vvar_odd_, vtmp_);
1029 }
1030 uni_vmovups(vsqrtvar_even_, vvar_even_);
1031 uni_vaddps(vsqrtvar_even_, vsqrtvar_even_, veps_);
1032 uni_vsqrtps(vsqrtvar_even_, vsqrtvar_even_);
1033 vdivps(vsqrtvar_even_, vone_, vsqrtvar_even_);
1034
1035 uni_vmovups(vsqrtvar_odd_, vvar_odd_);
1036 uni_vaddps(vsqrtvar_odd_, vsqrtvar_odd_, veps_);
1037 uni_vsqrtps(vsqrtvar_odd_, vsqrtvar_odd_);
1038 vdivps(vsqrtvar_odd_, vone_, vsqrtvar_odd_);
1039 }
1040
1041 void compute_bnorm_avx2_ne_xf16(
1042 const bool is_c_blks_tail, bool stream_store_allowed) {
1043 const Vmm vsrc_even = v_;
1044 const Vmm vsrc_odd = vsrc_aux;
1045 if (is_c_blks_tail) {
1046 compute_bnorm(
1047 vsrc_even, vmean_, vsqrtvar_, stream_store_allowed, false);
1048 } else {
1049 helper_vmovups_(
1050 vsrc_even, vsrc_odd, vmmword[reg_ptr_src_ + reg_off_dat_]);
1051 helper_vmovups_.merge_interleaved_to_plain(
1052 vsrc_even, vsrc_odd, vtmp_);
1053 load_c_specifics(true);
1054 compute_bnorm(vsrc_even, vmean_even_, vsqrtvar_even_,
1055 stream_store_allowed, true);
1056
1057 load_c_specifics(true, simd_w * acc_type_size_);
1058 compute_bnorm(vsrc_odd, vmean_odd_, vsqrtvar_odd_,
1059 stream_store_allowed, true, stride_C_ * data_type_size_);
1060 }
1061 }
1062
1063 void compute_avx2_ne_xf16(bool stream_store_allowed) {
1064 Label label_C, label_S, label_C_tail, label_C_end, label_S_C_tail;
1065 mov(reg_C_, dword[PARAM_ADDR(C)]);
1066 L(label_C);
1067 {
1068 cmp(reg_C_, 1);
1069 jle(label_C_tail, T_NEAR);
1070
1071 mov(reg_off_dat_, reg_off_dat_save_);
1072 load_two_c_mean_sqrtvar();
1073
1074 mov(reg_S_, dword[PARAM_ADDR(S)]);
1075 L(label_S);
1076 {
1077 compute_bnorm_avx2_ne_xf16(false, stream_store_allowed);
1078
1079 add(reg_off_dat_, stride_S_ * data_type_size_);
1080 dec(reg_S_);
1081 jnz(label_S, T_NEAR);
1082 }
1083 add(reg_off_dat_save_, 2 * stride_C_ * data_type_size_);
1084 add(reg_off_c_, 2 * simd_w * acc_type_size_);
1085
1086 sub(reg_C_, 2);
1087 jnz(label_C, T_NEAR);
1088 }
1089
1090 L(label_C_tail);
1091 {
1092 cmp(reg_C_, 0);
1093 jz(label_C_end, T_NEAR);
1094
1095 mov(reg_off_dat_, reg_off_dat_save_);
1096 load_c_specifics(false);
1097
1098 mov(reg_S_, dword[PARAM_ADDR(S)]);
1099 L(label_S_C_tail);
1100 {
1101 compute_bnorm_avx2_ne_xf16(true, stream_store_allowed);
1102
1103 add(reg_off_dat_, stride_S_ * data_type_size_);
1104 dec(reg_S_);
1105 jnz(label_S_C_tail, T_NEAR);
1106 }
1107 }
1108 L(label_C_end);
1109 }
1110
1111 void compute_blocked(bool stream_store_allowed) {
1112 Label label_C, label_S;
1113 mov(reg_C_, dword[PARAM_ADDR(C)]);
1114 L(label_C);
1115 {
1116 mov(reg_off_dat_, reg_off_dat_save_);
1117
1118 load_c_specifics(false);
1119
1120 mov(reg_S_, dword[PARAM_ADDR(S)]);
1121 L(label_S);
1122 {
1123 compute_bnorm(
1124 v_, vmean_, vsqrtvar_, stream_store_allowed, false);
1125
1126 add(reg_off_dat_, stride_S_ * data_type_size_);
1127
1128 dec(reg_S_);
1129 jnz(label_S);
1130 }
1131
1132 add(reg_off_dat_save_, stride_C_ * data_type_size_);
1133 add(reg_off_c_, simd_w * acc_type_size_);
1134
1135 dec(reg_C_);
1136 jnz(label_C);
1137 }
1138 }
1139
1140 void compute(bool stream_store_allowed) {
1141 Label label_N;
1142 mov(reg_N_, ptr[rsp + stack_off_N]);
1143 L(label_N);
1144 {
1145 // save reg_N_, because register is shared with reg_ptr_shift_
1146 mov(ptr[rsp + stack_off_N], reg_N_);
1147 mov(reg_ptr_shift_, ptr[rsp + stack_off_shift]);
1148
1149 xor_(reg_off_dat_save_, reg_off_dat_save_);
1150 xor_(reg_off_c_, reg_off_c_);
1151
1152 is_avx2_ne_xf16_ ? compute_avx2_ne_xf16(stream_store_allowed)
1153 : compute_blocked(stream_store_allowed);
1154
1155 if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
1156 xor_(reg_off_dat_save_, reg_off_dat_save_);
1157 xor_(reg_off_c_, reg_off_c_);
1158 add(reg_off_dat_save_, vlen / 2);
1159 add(reg_off_c_, vlen / 2);
1160
1161 compute_blocked(stream_store_allowed);
1162 }
1163
1164 add(reg_ptr_src_, stride_N_ * data_type_size_);
1165 add(reg_ptr_dst_, stride_N_ * data_type_size_);
1166 if (jit_relu_.with_relu_ && !jit_relu_.with_relu_inf_only_)
1167 add(reg_ptr_ws_, stride_N_ / bits_per_byte);
1168
1169 // restore reg_N_, because register is shared with reg_ptr_shift_
1170 mov(reg_N_, ptr[rsp + stack_off_N]);
1171 dec(reg_N_);
1172 jnz(label_N);
1173 }
1174 }
1175
1176 jit_bnorm_fwd_t(const batch_normalization_pd_t *pd,
1177 const jit_memory_tag_kind_t tag_kind)
1178 : jit_generator(jit_name())
1179 , pd_(pd)
1180 , tag_kind_(tag_kind)
1181 , vlen(get_vlen<isa>(tag_kind))
1182 , simd_w(get_simd_w<isa>(tag_kind))
1183 , is_avx2_ne_xf16_(is_avx2_ne_xf16<isa>(pd))
1184 , jit_tail_(pd, this, reg_tmp_, reg_blk_has_tail_, reg_C_, vtail_mask_,
1185 ktail_mask_)
1186 , jit_relu_(pd, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_,
1187 vstore_mask_, kstore_mask_, valpha, vmask, reg_alpha_)
1188 , helper_vmovups_(pd, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
1189 static_assert(utils::one_of(isa, sse41, avx2, avx512_core),
1190 "unsupported isa");
1191
1192 std::tie(stride_N_, stride_S_, stride_C_)
1193 = get_data_strides<isa>(pd_, tag_kind);
1194
1195 data_type_size_ = types::data_type_size(pd->src_md()->data_type);
1196 acc_type_size_ = sizeof(acc_data_t);
1197 }
1198
1199 void generate() override {
1200 bool is_xf16 = utils::one_of(
1201 pd_->src_md()->data_type, data_type::bf16, data_type::f16);
1202 const bool is_tail_in_nspc_format
1203 = tag_kind_ == jit_memory_tag_kind_t::nspc
1204 && jit_tail_.tail_ != 0;
1205 const bool stream_store_allowed = !is_xf16 && !is_tail_in_nspc_format;
1206
1207 preamble();
1208 if (helper_vmovups_.bf16_emu_)
1209 helper_vmovups_.bf16_emu_->init_vcvtneps2bf16();
1210 sub(rsp, stack_size_required);
1211 load_common_params();
1212 jit_relu_.fwd_prepare_relu();
1213 jit_tail_.prepare_tail();
1214
1215 Label normal_store, end_store;
1216 test(reg_ptr_dst_, vlen - 1);
1217 jnz(normal_store, T_NEAR);
1218 compute(stream_store_allowed);
1219 jmp(end_store, T_NEAR);
1220 L(normal_store);
1221 { compute(false); }
1222 L(end_store);
1223
1224 add(rsp, stack_size_required);
1225 postamble();
1226 }
1227};
1228
1229template <cpu_isa_t isa>
1230struct jit_bnorm_bwd_t : public jit_generator {
1231 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_bwd_t)
1232 using Vmm = typename cpu_isa_traits<isa>::Vmm;
1233
1234 const AddressFrame &vmmword
1235 = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
1236
1237 struct call_params_t {
1238 size_t N, C, S;
1239 const void *src, *diff_src, *diff_dst;
1240 const uint8_t *ws;
1241 const acc_data_t *mean, *var;
1242 const acc_data_t *scale, *diff_scale, *diff_shift;
1243 size_t blk_has_tail;
1244 };
1245
1246 const Reg64 reg_param_ = abi_param1;
1247 const Reg64 reg_tmp_ = abi_not_param1;
1248 const Reg64 reg_N_ = rsi;
1249 const Reg64 reg_S_ = rax;
1250 const Reg64 reg_C_ = rdx;
1251 const Reg64 reg_off_c_ = rbx;
1252 const Reg64 reg_blk_has_tail_ = rbp;
1253
1254 const Reg64 reg_off_dat_ = r8;
1255 const Reg64 reg_off_dat_save_ = r9;
1256 const Reg64 reg_ptr_c_ = r10;
1257 const Reg64 reg_ptr_ws_ = r11;
1258 const Reg64 reg_ptr_diff_dst_ = r12;
1259 const Reg64 reg_ptr_diff_src_ = r13;
1260 const Reg64 reg_ptr_src_ = r14;
1261
1262 const Vmm vzero_ = Vmm(0);
1263 const Vmm vone_ = Vmm(1);
1264 const Vmm vmean_ = Vmm(2);
1265 const Vmm vsqrtvar_ = Vmm(3);
1266 const Vmm vgamma_ = Vmm(4);
1267 const Vmm vdiff_gamma_ = Vmm(5);
1268 const Vmm vdiff_beta_ = Vmm(6);
1269 const Vmm veps_ = Vmm(7);
1270 const Vmm vNS_ = Vmm(8);
1271 const Vmm vtmp_ = Vmm(9);
1272 const Vmm v_ = Vmm(10);
1273 const Vmm vtail_mask_ = Vmm(11);
1274 const Vmm vstore_mask_ = vtmp_;
1275
1276 const Opmask kstore_mask_ = k1;
1277 const Opmask ktail_mask_ = k2;
1278
1279 const batch_normalization_pd_t *pd_;
1280 const jit_memory_tag_kind_t tag_kind_;
1281 const int vlen;
1282 const int simd_w;
1283 jit_bnorm_process_tail_t<isa> jit_tail_;
1284 jit_bnorm_process_relu_t<isa> jit_relu_;
1285 helper_vmovups_data_t<isa> helper_vmovups_;
1286 int stride_N_, stride_S_, stride_C_;
1287 size_t data_type_size_, acc_type_size_;
1288
1289 void load_common_params() {
1290#define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
1291 mov(reg_ptr_src_, PARAM_PTR(src));
1292 mov(reg_ptr_diff_src_, PARAM_PTR(diff_src));
1293 mov(reg_ptr_diff_dst_, PARAM_PTR(diff_dst));
1294 mov(reg_ptr_ws_, PARAM_PTR(ws));
1295#undef PARAM_PTR
1296
1297 Xmm x = Xmm(v_.getIdx());
1298
1299 mov(reg_tmp_, float2int(pd_->desc()->batch_norm_epsilon));
1300 uni_vmovq(x, reg_tmp_);
1301 uni_vbroadcastss(veps_, x);
1302
1303 mov(reg_tmp_, float2int(1.f));
1304 uni_vmovq(x, reg_tmp_);
1305 uni_vbroadcastss(vone_, x);
1306
1307 const int S = pd_->D() * pd_->H() * pd_->W();
1308 mov(reg_tmp_, float2int(pd_->MB() * S));
1309 uni_vmovq(x, reg_tmp_);
1310 uni_vbroadcastss(vNS_, x);
1311
1312 mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);
1313 }
1314
1315 void load_c_specifics() {
1316 mov(reg_ptr_c_, ptr[PARAM_ADDR(mean)]);
1317 jit_tail_.uni_vmovups_maybe_tail(
1318 vmean_, vmmword[reg_ptr_c_ + reg_off_c_]);
1319
1320 mov(reg_ptr_c_, ptr[PARAM_ADDR(var)]);
1321 jit_tail_.uni_vmovups_maybe_tail(
1322 vsqrtvar_, vmmword[reg_ptr_c_ + reg_off_c_]);
1323 uni_vaddps(vsqrtvar_, vsqrtvar_, veps_);
1324 uni_vsqrtps(vsqrtvar_, vsqrtvar_);
1325
1326 if (isa == sse41) {
1327 movups(vtmp_, vone_);
1328 divps(vtmp_, vsqrtvar_);
1329 movups(vsqrtvar_, vtmp_);
1330 } else
1331 vdivps(vsqrtvar_, vone_, vsqrtvar_);
1332
1333 if (pd_->use_scale()) {
1334 mov(reg_ptr_c_, ptr[PARAM_ADDR(scale)]);
1335 jit_tail_.uni_vmovups_maybe_tail(
1336 vgamma_, vmmword[reg_ptr_c_ + reg_off_c_]);
1337 }
1338
1339 if (calculate_diff_stats()) {
1340 mov(reg_ptr_c_, ptr[PARAM_ADDR(diff_scale)]);
1341 jit_tail_.uni_vmovups_maybe_tail(
1342 vdiff_gamma_, vmmword[reg_ptr_c_ + reg_off_c_]);
1343 uni_vmulps(vdiff_gamma_, vdiff_gamma_, vsqrtvar_);
1344 uni_vdivps(vdiff_gamma_, vdiff_gamma_, vNS_);
1345 mov(reg_ptr_c_, ptr[PARAM_ADDR(diff_shift)]);
1346 jit_tail_.uni_vmovups_maybe_tail(
1347 vdiff_beta_, vmmword[reg_ptr_c_ + reg_off_c_]);
1348 uni_vdivps(vdiff_beta_, vdiff_beta_, vNS_);
1349 }
1350 }
1351
1352 void compute_bnorm(bool stream_store_allowed) {
1353 helper_vmovups_(v_, vmmword[reg_ptr_diff_dst_ + reg_off_dat_]);
1354 jit_relu_.bwd_process_relu(v_);
1355
1356 if (calculate_diff_stats()) {
1357 uni_vsubps(v_, v_, vdiff_beta_);
1358 helper_vmovups_(vtmp_, vmmword[reg_ptr_src_ + reg_off_dat_]);
1359 uni_vsubps(vtmp_, vtmp_, vmean_);
1360 uni_vmulps(vtmp_, vtmp_, vdiff_gamma_);
1361 uni_vsubps(v_, v_, vtmp_);
1362 }
1363
1364 if (pd_->use_scale()) uni_vmulps(v_, v_, vgamma_);
1365 uni_vmulps(v_, v_, vsqrtvar_);
1366
1367 if (stream_store_allowed) {
1368 uni_vmovntps(vmmword[reg_ptr_diff_src_ + reg_off_dat_], v_);
1369 } else {
1370 helper_vmovups_(vmmword[reg_ptr_diff_src_ + reg_off_dat_], v_);
1371 }
1372 }
1373
1374 void compute_blocked(bool stream_store_allowed) {
1375 Label label_C, label_S;
1376 mov(reg_C_, dword[PARAM_ADDR(C)]);
1377 L(label_C);
1378 {
1379 mov(reg_off_dat_, reg_off_dat_save_);
1380
1381 load_c_specifics();
1382
1383 mov(reg_S_, dword[PARAM_ADDR(S)]);
1384 L(label_S);
1385 {
1386 compute_bnorm(stream_store_allowed);
1387
1388 add(reg_off_dat_, stride_S_ * data_type_size_);
1389
1390 dec(reg_S_);
1391 jnz(label_S);
1392 }
1393
1394 add(reg_off_dat_save_, stride_C_ * data_type_size_);
1395 add(reg_off_c_, simd_w * acc_type_size_);
1396
1397 dec(reg_C_);
1398 jnz(label_C);
1399 }
1400 }
1401
1402 void compute_nspc(bool stream_store_allowed) {
1403 Label label_C, label_S;
1404 mov(reg_S_, dword[PARAM_ADDR(S)]);
1405 L(label_S);
1406 {
1407 mov(reg_off_dat_, reg_off_dat_save_);
1408 xor_(reg_off_c_, reg_off_c_);
1409
1410 mov(reg_C_, dword[PARAM_ADDR(C)]);
1411 L(label_C);
1412 {
1413 load_c_specifics();
1414
1415 compute_bnorm(stream_store_allowed);
1416
1417 add(reg_off_c_, simd_w * acc_type_size_);
1418 add(reg_off_dat_, stride_C_ * data_type_size_);
1419
1420 dec(reg_C_);
1421 jnz(label_C);
1422 }
1423
1424 add(reg_off_dat_save_, stride_S_ * data_type_size_);
1425
1426 dec(reg_S_);
1427 jnz(label_S);
1428 }
1429 }
1430
1431 void compute(bool stream_store_allowed) {
1432 Label label_N;
1433 mov(reg_N_, dword[PARAM_ADDR(N)]);
1434 L(label_N);
1435 {
1436 xor_(reg_off_dat_save_, reg_off_dat_save_);
1437 xor_(reg_off_c_, reg_off_c_);
1438
1439 tag_kind_ == jit_memory_tag_kind_t::nspc
1440 ? compute_nspc(stream_store_allowed)
1441 : compute_blocked(stream_store_allowed);
1442
1443 if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
1444 xor_(reg_off_dat_save_, reg_off_dat_save_);
1445 xor_(reg_off_c_, reg_off_c_);
1446 add(reg_off_dat_save_, vlen / 2);
1447 add(reg_off_c_, vlen / 2);
1448
1449 compute_blocked(stream_store_allowed);
1450 }
1451
1452 add(reg_ptr_src_, stride_N_ * data_type_size_);
1453 add(reg_ptr_diff_src_, stride_N_ * data_type_size_);
1454 add(reg_ptr_diff_dst_, stride_N_ * data_type_size_);
1455 add(reg_ptr_ws_, stride_N_ / bits_per_byte);
1456
1457 dec(reg_N_);
1458 jnz(label_N);
1459 }
1460 }
1461
1462 bool calculate_diff_stats() const { return !pd_->use_global_stats(); }
1463
1464 jit_bnorm_bwd_t(const batch_normalization_pd_t *pd,
1465 const jit_memory_tag_kind_t tag_kind)
1466 : jit_generator(jit_name())
1467 , pd_(pd)
1468 , tag_kind_(tag_kind)
1469 , vlen(get_vlen<isa>(tag_kind))
1470 , simd_w(get_simd_w<isa>(tag_kind))
1471 , jit_tail_(pd, this, reg_tmp_, reg_blk_has_tail_, reg_C_, vtail_mask_,
1472 ktail_mask_)
1473 , jit_relu_(pd, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_,
1474 vstore_mask_, kstore_mask_)
1475 , helper_vmovups_(pd, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
1476 static_assert(utils::one_of(isa, sse41, avx2, avx512_core),
1477 "unsupported isa");
1478
1479 std::tie(stride_N_, stride_S_, stride_C_)
1480 = get_data_strides<isa>(pd_, tag_kind);
1481
1482 data_type_size_ = types::data_type_size(pd->src_md()->data_type);
1483 acc_type_size_ = sizeof(acc_data_t);
1484 }
1485
1486 void generate() override {
1487 bool is_bf16 = pd_->src_md()->data_type == data_type::bf16;
1488 bool is_f16 = pd_->src_md()->data_type == data_type::f16;
1489 const bool is_tail_in_nspc_format
1490 = tag_kind_ == jit_memory_tag_kind_t::nspc
1491 && jit_tail_.tail_ != 0;
1492 const bool stream_store_allowed
1493 = !is_bf16 && !is_f16 && !is_tail_in_nspc_format;
1494
1495 preamble();
1496 if (helper_vmovups_.bf16_emu_)
1497 helper_vmovups_.bf16_emu_->init_vcvtneps2bf16();
1498 load_common_params();
1499 jit_relu_.bwd_prepare_relu();
1500 jit_tail_.prepare_tail();
1501
1502 Label normal_store, end_store;
1503 test(reg_ptr_diff_src_, vlen - 1);
1504 jnz(normal_store, T_NEAR);
1505 compute(stream_store_allowed);
1506 jmp(end_store, T_NEAR);
1507 L(normal_store);
1508 { compute(false); }
1509 L(end_store);
1510
1511 postamble();
1512 }
1513};
1514
1515template <cpu_isa_t isa>
1516struct jit_bnorm_bwd_diff_ss_t : public jit_generator {
1517 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_bwd_diff_ss_t)
1518 using Vmm = typename cpu_isa_traits<isa>::Vmm;
1519
1520 const AddressFrame &vmmword
1521 = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
1522
1523 struct call_params_t {
1524 size_t N, C, S;
1525 const void *src, *diff_dst;
1526 const uint8_t *ws;
1527 const acc_data_t *mean, *var;
1528 const acc_data_t *diff_gamma, *diff_beta;
1529 size_t blk_has_tail;
1530 };
1531
1532 const Reg64 reg_param_ = abi_param1;
1533 const Reg64 reg_tmp_ = abi_not_param1;
1534 const Reg64 reg_N_ = rsi;
1535 const Reg64 reg_S_ = rax;
1536 const Reg64 reg_C_ = rdx;
1537 const Reg64 reg_off_c_ = rbx;
1538 const Reg64 reg_blk_has_tail_ = rbp;
1539
1540 const Reg64 reg_off_dat_ = r8;
1541 const Reg64 reg_off_dat_save_ = r9;
1542 const Reg64 reg_ptr_c_ = r10;
1543 const Reg64 reg_ptr_diff_gamma_ = r11;
1544 const Reg64 reg_ptr_diff_beta_ = r12;
1545 const Reg64 reg_ptr_ws_ = r13;
1546 const Reg64 reg_ptr_diff_dst_ = r14;
1547 const Reg64 reg_ptr_src_ = r15;
1548
1549 const Vmm vtail_mask_ = Vmm(0);
1550 const Vmm v_ = Vmm(1);
1551 const Vmm vtmp_ = Vmm(2);
1552 const Vmm vstore_mask_ = vtmp_;
1553 const Vmm vzero_ = Vmm(3);
1554 const Vmm veps_ = Vmm(4);
1555 const Vmm vone_ = Vmm(5);
1556 // Diff_beta, diff_gamma and one of the statistic values(mean or sqrtvar)
1557 // are unrolled i.e.three vmms are needed to unroll one c block at any moment,
1558 // therefore the number of registers which are used to unrolling must to be
1559 // divisible by three.
1560 static constexpr int min_idx_to_unroll_ = 6;
1561 static constexpr int max_idx_to_unroll_ = isa == avx512_core ? 27 : 15;
1562 static constexpr int number_of_unrolled_variables_ = 3;
1563 static constexpr int number_of_vmms_to_unrolling_variables_
1564 = max_idx_to_unroll_ - min_idx_to_unroll_;
1565 static_assert(number_of_vmms_to_unrolling_variables_
1566 % number_of_unrolled_variables_
1567 == 0
1568 && number_of_vmms_to_unrolling_variables_ != 0,
1569 "Number of register to unrolling must to be divisible by 3.");
1570
1571 const Opmask kstore_mask_ = k1;
1572 const Opmask ktail_mask_ = k2;
1573
1574 const batch_normalization_pd_t *pd_;
1575 const jit_memory_tag_kind_t tag_kind_;
1576 const int vlen;
1577 const int simd_w;
1578 jit_bnorm_process_tail_t<isa> jit_tail_;
1579 jit_bnorm_process_relu_t<isa> jit_relu_;
1580 helper_vmovups_data_t<isa> helper_vmovups_;
1581 int stride_N_, stride_S_, stride_C_;
1582 size_t data_type_size_, acc_type_size_;
1583
1584 void load_common_params() {
1585#define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
1586 mov(reg_ptr_src_, PARAM_PTR(src));
1587 mov(reg_ptr_diff_dst_, PARAM_PTR(diff_dst));
1588 mov(reg_ptr_ws_, PARAM_PTR(ws));
1589 mov(reg_ptr_diff_gamma_, PARAM_PTR(diff_gamma));
1590 mov(reg_ptr_diff_beta_, PARAM_PTR(diff_beta));
1591#undef PARAM_PTR
1592
1593 Xmm x = Xmm(v_.getIdx());
1594
1595 mov(reg_tmp_, float2int(pd_->desc()->batch_norm_epsilon));
1596 uni_vmovq(x, reg_tmp_);
1597 uni_vbroadcastss(veps_, x);
1598
1599 mov(reg_tmp_, float2int(1.f));
1600 uni_vmovq(x, reg_tmp_);
1601 uni_vbroadcastss(vone_, x);
1602
1603 mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);
1604 }
1605
1606 void zeroise() {
1607 Label label_zeroise;
1608 xor_(reg_off_c_, reg_off_c_);
1609 uni_vpxor(vzero_, vzero_, vzero_);
1610 mov(reg_C_, dword[PARAM_ADDR(C)]);
1611 L(label_zeroise);
1612 {
1613 jit_tail_.uni_vmovups_maybe_tail(
1614 vmmword[reg_ptr_diff_gamma_ + reg_off_c_], vzero_);
1615 jit_tail_.uni_vmovups_maybe_tail(
1616 vmmword[reg_ptr_diff_beta_ + reg_off_c_], vzero_);
1617 if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
1618 jit_tail_.uni_vmovups_maybe_tail(
1619 vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + vlen / 2],
1620 vzero_);
1621 jit_tail_.uni_vmovups_maybe_tail(
1622 vmmword[reg_ptr_diff_beta_ + reg_off_c_ + vlen / 2],
1623 vzero_);
1624 }
1625 add(reg_off_c_, simd_w * acc_type_size_);
1626 dec(reg_C_);
1627 jnz(label_zeroise);
1628 }
1629 }
1630
1631 void load_mean(const int c_blks_to_unroll = 1) {
1632 mov(reg_ptr_c_, ptr[PARAM_ADDR(mean)]);
1633
1634 const int start_idx = min_idx_to_unroll_;
1635 const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
1636 + min_idx_to_unroll_;
1637 const int step = simd_w * acc_type_size_;
1638
1639 for (int idx = start_idx, off = 0; idx < end_idx;
1640 idx += number_of_unrolled_variables_, off += step) {
1641 const Vmm vmean = Vmm(idx);
1642
1643 jit_tail_.uni_vmovups_maybe_tail(
1644 vmean, vmmword[reg_ptr_c_ + reg_off_c_ + off]);
1645 }
1646 }
1647
1648 void zeroise_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) {
1649 const int start_idx = min_idx_to_unroll_;
1650 const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
1651 + min_idx_to_unroll_;
1652
1653 for (int idx = start_idx; idx < end_idx;
1654 idx += number_of_unrolled_variables_) {
1655 const Vmm vdiff_beta = Vmm(idx + 1);
1656 const Vmm vdiff_gamma = Vmm(idx + 2);
1657
1658 uni_vpxor(vdiff_beta, vdiff_beta, vdiff_beta);
1659 uni_vpxor(vdiff_gamma, vdiff_gamma, vdiff_gamma);
1660 }
1661 }
1662
1663 void load_and_prepare_sqrtvar(const int c_blks_to_unroll = 1) {
1664 mov(reg_ptr_c_, ptr[PARAM_ADDR(var)]);
1665
1666 const int start_idx = min_idx_to_unroll_;
1667 const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
1668 + min_idx_to_unroll_;
1669 const int step = simd_w * acc_type_size_;
1670
1671 for (int idx = start_idx, off = 0; idx < end_idx;
1672 idx += number_of_unrolled_variables_, off += step) {
1673 const Vmm vsqrtvar = Vmm(idx);
1674
1675 jit_tail_.uni_vmovups_maybe_tail(
1676 vsqrtvar, vmmword[reg_ptr_c_ + reg_off_c_ + off]);
1677
1678 // 1.0 / sqrt(var + eps)
1679 uni_vaddps(vsqrtvar, vsqrtvar, veps_);
1680 uni_vsqrtps(vsqrtvar, vsqrtvar);
1681
1682 if (isa == sse41) {
1683 movups(vtmp_, vone_);
1684 divps(vtmp_, vsqrtvar);
1685 movups(vsqrtvar, vtmp_);
1686 } else
1687 vdivps(vsqrtvar, vone_, vsqrtvar);
1688 }
1689 }
1690
1691 void compute_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) {
1692 const int start_idx = min_idx_to_unroll_;
1693 const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
1694 + min_idx_to_unroll_;
1695 const int step = simd_w * data_type_size_;
1696
1697 for (int idx = start_idx, off = 0; idx < end_idx;
1698 idx += number_of_unrolled_variables_, off += step) {
1699 const Vmm vmean = Vmm(idx);
1700 const Vmm vdiff_beta = Vmm(idx + 1);
1701 const Vmm vdiff_gamma = Vmm(idx + 2);
1702
1703 helper_vmovups_(
1704 v_, vmmword[reg_ptr_diff_dst_ + reg_off_dat_ + off]);
1705
1706 jit_relu_.bwd_process_relu(
1707 v_, off / (bits_per_byte * data_type_size_));
1708
1709 // diff_beta
1710 uni_vaddps(vdiff_beta, vdiff_beta, v_);
1711
1712 helper_vmovups_(vtmp_, vmmword[reg_ptr_src_ + reg_off_dat_ + off]);
1713
1714 // diff_gamma, note that diff_gamma will be multiplied
1715 // by sqrtvar before store
1716 uni_vsubps(vtmp_, vtmp_, vmean);
1717 uni_vfmadd231ps(vdiff_gamma, vtmp_, v_);
1718 }
1719 }
1720
1721 void store_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) {
1722 const int start_idx = min_idx_to_unroll_;
1723 const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
1724 + min_idx_to_unroll_;
1725 const int step = simd_w * acc_type_size_;
1726
1727 for (int idx = start_idx, off = 0; idx < end_idx;
1728 idx += number_of_unrolled_variables_, off += step) {
1729 const Vmm vdiff_beta = Vmm(idx + 1);
1730
1731 jit_tail_.uni_vmovups_maybe_tail(
1732 vtmp_, vmmword[reg_ptr_diff_beta_ + reg_off_c_ + off]);
1733 uni_vaddps(vdiff_beta, vdiff_beta, vtmp_);
1734 jit_tail_.uni_vmovups_maybe_tail(
1735 vmmword[reg_ptr_diff_beta_ + reg_off_c_ + off], vdiff_beta);
1736 }
1737
1738 for (int idx = start_idx, off = 0; idx < end_idx;
1739 idx += number_of_unrolled_variables_, off += step) {
1740 const Vmm vsqrtvar = Vmm(idx);
1741 const Vmm vdiff_gamma = Vmm(idx + 2);
1742
1743 // multiply diff_gamma by 1.0/sqrt(var + eps)
1744 uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar);
1745
1746 jit_tail_.uni_vmovups_maybe_tail(
1747 vtmp_, vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + off]);
1748 uni_vaddps(vdiff_gamma, vdiff_gamma, vtmp_);
1749 jit_tail_.uni_vmovups_maybe_tail(
1750 vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + off],
1751 vdiff_gamma);
1752 }
1753 }
1754
1755 void compute_blocked() {
1756 Label label_C, label_S;
1757 mov(reg_C_, dword[PARAM_ADDR(C)]);
1758 L(label_C);
1759 {
1760 mov(reg_off_dat_, reg_off_dat_save_);
1761
1762 load_mean();
1763 zeroise_diff_beta_and_diff_gamma();
1764
1765 mov(reg_S_, dword[PARAM_ADDR(S)]);
1766 L(label_S);
1767 {
1768 compute_diff_beta_and_diff_gamma();
1769
1770 add(reg_off_dat_, stride_S_ * data_type_size_);
1771
1772 dec(reg_S_);
1773 jnz(label_S);
1774 }
1775
1776 load_and_prepare_sqrtvar();
1777 store_diff_beta_and_diff_gamma();
1778
1779 add(reg_off_dat_save_, stride_C_ * data_type_size_);
1780 add(reg_off_c_, simd_w * acc_type_size_);
1781
1782 dec(reg_C_);
1783 jnz(label_C);
1784 }
1785 }
1786
1787 void compute_nspc() {
1788 mov(reg_C_, dword[PARAM_ADDR(C)]);
1789
1790 constexpr int max_of_unrolled_c_blks
1791 = number_of_vmms_to_unrolling_variables_
1792 / number_of_unrolled_variables_;
1793 std::vector<Label> c_unroll_label(max_of_unrolled_c_blks + 1);
1794
1795 for (int c_blks_to_unroll = max_of_unrolled_c_blks;
1796 c_blks_to_unroll > 0; --c_blks_to_unroll) {
1797 L(c_unroll_label[c_blks_to_unroll]);
1798 {
1799 cmp(reg_C_, c_blks_to_unroll);
1800 jl(c_unroll_label[c_blks_to_unroll - 1], T_NEAR);
1801
1802 mov(reg_off_dat_, reg_off_dat_save_);
1803
1804 load_mean(c_blks_to_unroll);
1805 zeroise_diff_beta_and_diff_gamma(c_blks_to_unroll);
1806
1807 Label label_S;
1808 mov(reg_S_, dword[PARAM_ADDR(S)]);
1809 L(label_S);
1810 {
1811 compute_diff_beta_and_diff_gamma(c_blks_to_unroll);
1812
1813 add(reg_off_dat_, stride_S_ * data_type_size_);
1814
1815 dec(reg_S_);
1816 jnz(label_S);
1817 }
1818
1819 load_and_prepare_sqrtvar(c_blks_to_unroll);
1820 store_diff_beta_and_diff_gamma(c_blks_to_unroll);
1821
1822 add(reg_off_c_, c_blks_to_unroll * simd_w * acc_type_size_);
1823 add(reg_off_dat_save_,
1824 c_blks_to_unroll * stride_C_ * data_type_size_);
1825
1826 sub(reg_C_, c_blks_to_unroll);
1827 jmp(c_unroll_label[c_blks_to_unroll], T_NEAR);
1828 }
1829 }
1830 L(c_unroll_label[0]);
1831 }
1832
1833 void compute() {
1834 Label label_N;
1835 mov(reg_N_, dword[PARAM_ADDR(N)]);
1836 L(label_N);
1837 {
1838 xor_(reg_off_dat_save_, reg_off_dat_save_);
1839 xor_(reg_off_c_, reg_off_c_);
1840
1841 tag_kind_ == jit_memory_tag_kind_t::nspc ? compute_nspc()
1842 : compute_blocked();
1843
1844 if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
1845 xor_(reg_off_dat_save_, reg_off_dat_save_);
1846 xor_(reg_off_c_, reg_off_c_);
1847 add(reg_off_dat_save_, vlen / 2);
1848 add(reg_off_c_, vlen / 2);
1849
1850 compute_blocked();
1851 }
1852
1853 add(reg_ptr_src_, stride_N_ * data_type_size_);
1854 add(reg_ptr_diff_dst_, stride_N_ * data_type_size_);
1855 add(reg_ptr_ws_, stride_N_ / bits_per_byte);
1856
1857 dec(reg_N_);
1858 jnz(label_N);
1859 }
1860 }
1861
1862 jit_bnorm_bwd_diff_ss_t(const batch_normalization_pd_t *pd,
1863 const jit_memory_tag_kind_t tag_kind)
1864 : jit_generator(jit_name())
1865 , pd_(pd)
1866 , tag_kind_(tag_kind)
1867 , vlen(get_vlen<isa>(tag_kind))
1868 , simd_w(get_simd_w<isa>(tag_kind))
1869 , jit_tail_(pd, this, reg_tmp_, reg_blk_has_tail_, reg_C_, vtail_mask_,
1870 ktail_mask_)
1871 , jit_relu_(pd, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_,
1872 vstore_mask_, kstore_mask_)
1873 , helper_vmovups_(pd, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
1874 static_assert(utils::one_of(isa, sse41, avx2, avx512_core),
1875 "unsupported isa");
1876
1877 std::tie(stride_N_, stride_S_, stride_C_)
1878 = get_data_strides<isa>(pd_, tag_kind);
1879
1880 data_type_size_ = types::data_type_size(pd->src_md()->data_type);
1881 acc_type_size_ = sizeof(acc_data_t);
1882 }
1883
1884 void generate() override {
1885 preamble();
1886 load_common_params();
1887 jit_relu_.bwd_prepare_relu();
1888 jit_tail_.prepare_tail();
1889 zeroise();
1890 compute();
1891 postamble();
1892 }
1893};
1894
1895namespace bnorm_tbb_impl {
1896
1897template <cpu_isa_t isa>
1898struct driver_t : public c_compatible {
1899private:
1900 struct bnorm_dims_t {
1901 dim_t N, C, S;
1902 dim_t glob;
1903 };
1904
1905 DNNL_DISALLOW_COPY_AND_ASSIGN(driver_t);
1906
1907public:
1908 driver_t(const batch_normalization_pd_t *pd,
1909 const jit_memory_tag_kind_t tag_kind)
1910 : pd_(pd), tag_kind_(tag_kind), simd_w(get_simd_w<isa>(tag_kind)) {
1911 nthr_ = dnnl_get_max_threads();
1912 N_ = pd_->MB();
1913 S_ = pd_->D() * pd_->H() * pd_->W();
1914 C_ = pd_->C();
1915 C_blks_ = get_c_padded(pd_) / simd_w;
1916
1917 const size_t l3_size = platform::get_per_core_cache_size(3) * nthr_ / 2;
1918 int num_tensors = pd_->is_fwd() ? 1 : 2;
1919 dt_size_ = types::data_type_size(pd_->src_md()->data_type);
1920 const size_t working_set_size
1921 = dt_size_ * N_ * S_ * simd_w * num_tensors;
1922
1923 do_blocking_ = tag_kind_ == jit_memory_tag_kind_t::nspc
1924 ? false
1925 : working_set_size * C_blks_ >= l3_size / 2 && l3_size > 0;
1926
1927 if (tag_kind_ == jit_memory_tag_kind_t::nspc) {
1928 if (normalize_only(pd_)) {
1929 // blocks have to fit in a 4rth of L1 so that they don't get evicted
1930 // There are at most 6 tensors: src, dst, mean, var, scale, shift
1931 dim_t n_tensors = 2 + pd_->use_scale() + pd_->use_shift();
1932 C_blk_step_ = utils::saturate<dim_t>(1, C_blks_,
1933 platform::get_per_core_cache_size(1)
1934 / get_vlen<isa>(jit_memory_tag_kind_t::nspc)
1935 / n_tensors);
1936 } else
1937 C_blk_step_ = C_blks_;
1938 } else {
1939 C_blk_step_ = utils::saturate<dim_t>(
1940 1, C_blks_, l3_size / working_set_size);
1941 }
1942 }
1943
1944 status_t create_kernel() {
1945 if (pd_->is_fwd()) {
1946 CHECK(safe_ptr_assign(
1947 ker_fwd_, new jit_bnorm_fwd_t<isa>(pd_, tag_kind_)));
1948 CHECK(ker_fwd_->create_kernel());
1949 if (!pd_->stats_is_src()) {
1950 CHECK(safe_ptr_assign(ker_fwd_mean_,
1951 new jit_bnorm_fwd_mean_t<isa>(pd_, tag_kind_)));
1952 CHECK(safe_ptr_assign(ker_fwd_var_,
1953 new jit_bnorm_fwd_var_t<isa>(pd_, tag_kind_)));
1954 CHECK(ker_fwd_mean_->create_kernel());
1955 CHECK(ker_fwd_var_->create_kernel());
1956 }
1957 } else {
1958 CHECK(safe_ptr_assign(
1959 ker_bwd_, new jit_bnorm_bwd_t<isa>(pd_, tag_kind_)));
1960 CHECK(safe_ptr_assign(ker_bwd_diff_ss_,
1961 new jit_bnorm_bwd_diff_ss_t<isa>(pd_, tag_kind_)));
1962 CHECK(ker_bwd_->create_kernel());
1963 CHECK(ker_bwd_diff_ss_->create_kernel());
1964 }
1965 return status::success;
1966 }
1967
1968 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
1969 const batch_normalization_pd_t *pd) {
1970
1971 int nthrs = dnnl_get_max_threads();
1972 int C_PADDED = get_c_padded(pd);
1973
1974 auto sbuf_sz = use_tmp_stats(pd) * 2 * C_PADDED;
1975 auto pbuf_sz
1976 = (use_tmp_diff_scale(pd) + use_tmp_diff_shift(pd)) * C_PADDED;
1977 auto rbuf_sz = (pd->is_fwd() ? 1 : 2) * C_PADDED * nthrs;
1978
1979 scratchpad.book<acc_data_t>(key_bnorm_tmp_stats, sbuf_sz);
1980 scratchpad.book<acc_data_t>(key_bnorm_tmp_diff_ss, pbuf_sz);
1981 scratchpad.book<acc_data_t>(key_bnorm_reduction, rbuf_sz);
1982 }
1983
1984 void exec_fwd_step_stats(const dim_t C_blks, const bnorm_dims_t &nthr,
1985 const void *src, acc_data_t *mean, acc_data_t *var,
1986 acc_data_t *rbuf, bool blk_has_tail) {
1987 size_t stride_C, stride_N, stride_S;
1988 std::tie(stride_N, stride_S, stride_C)
1989 = get_data_strides<isa>(pd_, tag_kind_);
1990
1991 const int nthr_NS = nthr.N * nthr.S;
1992 const bool need_reduction = nthr_NS > 1;
1993 const dim_t tail_size = blk_has_tail ? C_ % simd_w : simd_w;
1994
1995 const dim_t size_C_stat = (C_blks - 1) * simd_w + tail_size;
1996
1997 auto reduce = [&](acc_data_t *stat, acc_data_t *r_stat) {
1998 if (!need_reduction) return;
1999 acc_data_t *loc_stat = r_stat;
2000
2001 for (dim_t c = 0; c < size_C_stat; ++c)
2002 stat[c] = loc_stat[c];
2003
2004 for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) {
2005 loc_stat += size_C_stat;
2006 for (dim_t c = 0; c < size_C_stat; ++c)
2007 stat[c] += loc_stat[c];
2008 }
2009
2010 for (dim_t c = 0; c < size_C_stat; ++c)
2011 stat[c] /= N_ * S_;
2012 };
2013
2014 // find local mean
2015 acc_data_t *r_mean = need_reduction ? rbuf : mean;
2016 parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
2017 assert(nthr_glob == nthr.glob);
2018 const auto ithr = map_thread(ithr_glob, nthr);
2019 bnorm_dims_t start, stop;
2020 work_distribution(C_blks, ithr, nthr, start, stop);
2021
2022 auto c = typename jit_bnorm_fwd_mean_t<isa>::call_params_t();
2023 c.N = stop.N - start.N;
2024 c.C = stop.C - start.C;
2025 c.S = stop.S - start.S;
2026
2027 const size_t d_off = start.N * stride_N + start.C * stride_C
2028 + start.S * stride_S;
2029 c.src = (void *)((char *)src + d_off * dt_size_);
2030 const int ithr_NS = ithr.N * nthr.S + ithr.S;
2031 c.mean = &r_mean[ithr_NS * size_C_stat + start.C * simd_w];
2032 c.blk_has_tail = blk_has_tail && stop.C == C_blks;
2033 c.do_normalise = !need_reduction;
2034 (*ker_fwd_mean_)(&c);
2035 });
2036
2037 // mean reduction
2038 reduce(mean, r_mean);
2039
2040 // find local var
2041 acc_data_t *r_var = need_reduction ? rbuf : var;
2042 parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
2043 assert(nthr_glob == nthr.glob);
2044 const auto ithr = map_thread(ithr_glob, nthr);
2045 bnorm_dims_t start, stop;
2046 work_distribution(C_blks, ithr, nthr, start, stop);
2047
2048 auto c = typename jit_bnorm_fwd_var_t<isa>::call_params_t();
2049 c.N = stop.N - start.N;
2050 c.C = stop.C - start.C;
2051 c.S = stop.S - start.S;
2052
2053 const size_t d_off = start.N * stride_N + start.C * stride_C
2054 + start.S * stride_S;
2055 c.src = (void *)((char *)src + d_off * dt_size_);
2056 const int ithr_NS = ithr.N * nthr.S + ithr.S;
2057 c.mean = &mean[start.C * simd_w];
2058 c.var = &r_var[ithr_NS * size_C_stat + start.C * simd_w];
2059 c.blk_has_tail = blk_has_tail && stop.C == C_blks;
2060 c.do_normalise = !need_reduction;
2061 (*ker_fwd_var_)(&c);
2062 });
2063
2064 // var reduction
2065 reduce(var, r_var);
2066 }
2067
2068 void exec_fwd_step_normalization(const dim_t C_blks,
2069 const bnorm_dims_t &nthr, const void *src, void *dst,
2070 const acc_data_t *scale, const acc_data_t *shift,
2071 const acc_data_t *mean, const acc_data_t *var, uint8_t *ws,
2072 bool blk_has_tail) {
2073 size_t stride_C, stride_N, stride_S;
2074 std::tie(stride_N, stride_S, stride_C)
2075 = get_data_strides<isa>(pd_, tag_kind_);
2076
2077 parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
2078 assert(nthr_glob == nthr.glob);
2079 const auto ithr = map_thread(ithr_glob, nthr);
2080 bnorm_dims_t start, stop;
2081 work_distribution(C_blks, ithr, nthr, start, stop);
2082
2083 auto c = typename jit_bnorm_fwd_t<isa>::call_params_t();
2084 c.N = stop.N - start.N;
2085 c.C = stop.C - start.C;
2086 c.S = stop.S - start.S;
2087 const size_t d_off = start.N * stride_N + start.C * stride_C
2088 + start.S * stride_S;
2089 c.src = (void *)((char *)src + d_off * dt_size_);
2090 c.dst = (void *)((char *)dst + d_off * dt_size_);
2091 c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr;
2092 c.mean = &mean[start.C * simd_w];
2093 c.var = &var[start.C * simd_w];
2094 c.scale = scale ? &scale[start.C * simd_w] : nullptr;
2095 c.shift = shift ? &shift[start.C * simd_w] : nullptr;
2096 c.blk_has_tail = blk_has_tail && stop.C == C_blks;
2097 (*ker_fwd_)(&c);
2098 });
2099 }
2100
2101 void exec_fwd(const void *src, void *dst, const acc_data_t *scale,
2102 const acc_data_t *shift, acc_data_t *mean, acc_data_t *var,
2103 uint8_t *ws, const memory_tracking::grantor_t &scratchpad) {
2104 auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction);
2105 if (use_tmp_stats(pd_)) {
2106 auto sbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_stats);
2107 mean = sbuf;
2108 var = sbuf + C_blks_ * simd_w;
2109 }
2110
2111 size_t stride_C;
2112 std::tie(std::ignore, std::ignore, stride_C)
2113 = get_data_strides<isa>(pd_, tag_kind_);
2114
2115 dim_t C_blk_step = C_blk_step_;
2116 auto nthr = bnorm_dims_t();
2117
2118 thread_distribution(C_blk_step, nthr);
2119
2120 for (dim_t C_blk_st = 0; C_blk_st < C_blks_; C_blk_st += C_blk_step) {
2121 if (C_blk_st + C_blk_step > C_blks_) {
2122 C_blk_step = C_blks_ - C_blk_st;
2123 thread_distribution(C_blk_step, nthr);
2124 }
2125
2126 if (!pd_->stats_is_src()) {
2127 exec_fwd_step_stats(C_blk_step, nthr,
2128 (void *)((char *)src
2129 + (C_blk_st * stride_C) * dt_size_),
2130 mean + C_blk_st * simd_w, var + C_blk_st * simd_w, rbuf,
2131 (C_blk_st + C_blk_step) * simd_w > C_);
2132 }
2133 exec_fwd_step_normalization(C_blk_step, nthr,
2134 (void *)((char *)src + (C_blk_st * stride_C) * dt_size_),
2135 (void *)((char *)dst + (C_blk_st * stride_C) * dt_size_),
2136 scale + C_blk_st * simd_w, shift + C_blk_st * simd_w,
2137 mean + C_blk_st * simd_w, var + C_blk_st * simd_w,
2138 ws + C_blk_st * stride_C / bits_per_byte,
2139 (C_blk_st + C_blk_step) * simd_w > C_);
2140 }
2141 }
2142
2143 void exec_bwd_step_diff_ss(const dim_t C_blks, const bnorm_dims_t &nthr,
2144 const void *src, const void *diff_dst, const acc_data_t *mean,
2145 const acc_data_t *var, const uint8_t *ws, acc_data_t *diff_scale,
2146 acc_data_t *diff_shift, acc_data_t *rbuf, bool blk_has_tail) {
2147 size_t stride_C, stride_N, stride_S;
2148 std::tie(stride_N, stride_S, stride_C)
2149 = get_data_strides<isa>(pd_, tag_kind_);
2150
2151 const dim_t tail_size = blk_has_tail ? C_ % simd_w : simd_w;
2152 const dim_t size_C_stat = (C_blks - 1) * simd_w + tail_size;
2153
2154 const int nthr_NS = nthr.N * nthr.S;
2155 const bool need_reduction = nthr_NS > 1;
2156
2157 acc_data_t *diff_gamma = diff_scale;
2158 acc_data_t *diff_beta = diff_shift;
2159
2160 acc_data_t *const r_diff_gamma = need_reduction ? rbuf : diff_gamma;
2161 acc_data_t *const r_diff_beta
2162 = need_reduction ? rbuf + nthr_NS * size_C_stat : diff_beta;
2163
2164 auto reduce = [&]() {
2165 if (!need_reduction) return;
2166
2167 // diff_gamma
2168 const acc_data_t *loc_diff_gamma = r_diff_gamma;
2169 for (dim_t c = 0; c < size_C_stat; ++c)
2170 diff_gamma[c] = loc_diff_gamma[c];
2171 for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) {
2172 loc_diff_gamma += size_C_stat;
2173 for (dim_t c = 0; c < size_C_stat; ++c)
2174 diff_gamma[c] += loc_diff_gamma[c];
2175 }
2176
2177 // diff_beta
2178 const acc_data_t *loc_diff_beta = r_diff_beta;
2179 for (dim_t c = 0; c < size_C_stat; ++c)
2180 diff_beta[c] = loc_diff_beta[c];
2181 for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) {
2182 loc_diff_beta += size_C_stat;
2183 for (dim_t c = 0; c < size_C_stat; ++c)
2184 diff_beta[c] += loc_diff_beta[c];
2185 }
2186 };
2187
2188 parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
2189 assert(nthr_glob == nthr.glob);
2190 const auto ithr = map_thread(ithr_glob, nthr);
2191 bnorm_dims_t start, stop;
2192 work_distribution(C_blks, ithr, nthr, start, stop);
2193
2194 const int ithr_NS = ithr.N * nthr.S + ithr.S;
2195 acc_data_t *loc_diff_gamma = &r_diff_gamma[ithr_NS * size_C_stat];
2196 acc_data_t *loc_diff_beta = &r_diff_beta[ithr_NS * size_C_stat];
2197
2198 auto c = typename jit_bnorm_bwd_diff_ss_t<isa>::call_params_t();
2199 c.N = stop.N - start.N;
2200 c.C = stop.C - start.C;
2201 c.S = stop.S - start.S;
2202
2203 const size_t d_off = start.N * stride_N + start.C * stride_C
2204 + start.S * stride_S;
2205 c.src = (void *)((char *)src + d_off * dt_size_);
2206 c.diff_dst = (void *)((char *)diff_dst + d_off * dt_size_);
2207 c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr;
2208 c.mean = &mean[start.C * simd_w];
2209 c.var = &var[start.C * simd_w];
2210 c.diff_gamma = &loc_diff_gamma[start.C * simd_w];
2211 c.diff_beta = &loc_diff_beta[start.C * simd_w];
2212 c.blk_has_tail = blk_has_tail && stop.C == C_blks;
2213
2214 (*ker_bwd_diff_ss_)(&c);
2215 });
2216
2217 reduce();
2218 }
2219
2220 void exec_bwd_step_normalization(const dim_t C_blks,
2221 const bnorm_dims_t &nthr, const void *src, void *diff_src,
2222 const void *diff_dst, const acc_data_t *mean, const acc_data_t *var,
2223 const uint8_t *ws, const acc_data_t *scale,
2224 const acc_data_t *diff_scale, const acc_data_t *diff_shift,
2225 bool blk_has_tail) {
2226 size_t stride_C, stride_N, stride_S;
2227 std::tie(stride_N, stride_S, stride_C)
2228 = get_data_strides<isa>(pd_, tag_kind_);
2229
2230 parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
2231 assert(nthr_glob == nthr.glob);
2232 const auto ithr = map_thread(ithr_glob, nthr);
2233 bnorm_dims_t start, stop;
2234 work_distribution(C_blks, ithr, nthr, start, stop);
2235
2236 auto c = typename jit_bnorm_bwd_t<isa>::call_params_t();
2237 c.N = stop.N - start.N;
2238 c.C = stop.C - start.C;
2239 c.S = stop.S - start.S;
2240
2241 const size_t d_off = start.N * stride_N + start.C * stride_C
2242 + start.S * stride_S;
2243 c.src = (void *)((char *)src + d_off * dt_size_);
2244 c.diff_src = (void *)((char *)diff_src + d_off * dt_size_);
2245 c.diff_dst = (void *)((char *)diff_dst + d_off * dt_size_);
2246 c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr;
2247 c.mean = &mean[start.C * simd_w];
2248 c.var = &var[start.C * simd_w];
2249 c.scale = scale ? &scale[start.C * simd_w] : nullptr;
2250 c.diff_scale = &diff_scale[start.C * simd_w];
2251 c.diff_shift = &diff_shift[start.C * simd_w];
2252 c.blk_has_tail = blk_has_tail && stop.C == C_blks;
2253
2254 (*ker_bwd_)(&c);
2255 });
2256 }
2257
2258 void exec_bwd(const void *src, void *diff_src, const void *diff_dst,
2259 const acc_data_t *scale, acc_data_t *diff_scale,
2260 acc_data_t *diff_shift, const acc_data_t *mean,
2261 const acc_data_t *var, const uint8_t *ws,
2262 const memory_tracking::grantor_t &scratchpad) {
2263 auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction);
2264 if (use_tmp_diff_scale(pd_)) {
2265 auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss);
2266 diff_scale = pbuf;
2267 }
2268 if (use_tmp_diff_shift(pd_)) {
2269 auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss);
2270 size_t shift_off = use_tmp_diff_scale(pd_) ? pd_->C() : 0;
2271 diff_shift = &pbuf[shift_off];
2272 }
2273
2274 size_t stride_C;
2275 std::tie(std::ignore, std::ignore, stride_C)
2276 = get_data_strides<isa>(pd_, tag_kind_);
2277
2278 dim_t C_blk_step = C_blk_step_;
2279 auto nthr = bnorm_dims_t();
2280
2281 thread_distribution(C_blk_step, nthr);
2282
2283 for (dim_t C_blk_st = 0; C_blk_st < C_blks_; C_blk_st += C_blk_step) {
2284 if (C_blk_st + C_blk_step > C_blks_) {
2285 C_blk_step = C_blks_ - C_blk_st;
2286 thread_distribution(C_blk_step, nthr);
2287 }
2288
2289 exec_bwd_step_diff_ss(C_blk_step, nthr,
2290 (void *)((char *)src + (C_blk_st * stride_C) * dt_size_),
2291 (void *)((char *)diff_dst
2292 + (C_blk_st * stride_C) * dt_size_),
2293 mean + C_blk_st * simd_w, var + C_blk_st * simd_w,
2294 ws + C_blk_st * stride_C / bits_per_byte,
2295 diff_scale + C_blk_st * simd_w,
2296 diff_shift + C_blk_st * simd_w, rbuf,
2297 (C_blk_st + C_blk_step) * simd_w > C_);
2298
2299 exec_bwd_step_normalization(C_blk_step, nthr,
2300 (void *)((char *)src + (C_blk_st * stride_C) * dt_size_),
2301 (void *)((char *)diff_src
2302 + (C_blk_st * stride_C) * dt_size_),
2303 (void *)((char *)diff_dst
2304 + (C_blk_st * stride_C) * dt_size_),
2305 mean + C_blk_st * simd_w, var + C_blk_st * simd_w,
2306 ws + C_blk_st * stride_C / bits_per_byte,
2307 scale + C_blk_st * simd_w, diff_scale + C_blk_st * simd_w,
2308 diff_shift + C_blk_st * simd_w,
2309 (C_blk_st + C_blk_step) * simd_w > C_);
2310 }
2311 }
2312
2313private:
2314 static bool use_tmp_stats(const batch_normalization_pd_t *pd) {
2315 return !pd->stats_is_src()
2316 && pd->desc()->prop_kind == prop_kind::forward_inference;
2317 }
2318
2319 static bool use_tmp_diff_scale(const batch_normalization_pd_t *pd) {
2320 return (!pd->is_fwd() && !pd->use_scale())
2321 || pd->desc()->prop_kind == prop_kind::backward_data;
2322 }
2323
2324 static bool use_tmp_diff_shift(const batch_normalization_pd_t *pd) {
2325 return (!pd->is_fwd() && !pd->use_shift())
2326 || pd->desc()->prop_kind == prop_kind::backward_data;
2327 }
2328
2329 void thread_distribution_nspc(dim_t C_blks, bnorm_dims_t &nthr) {
2330 if (normalize_only(pd_)) {
2331 // We want to keep some granularity on S so that we can
2332 // stay on 1 socket if possible, this is why we divide
2333 // work in chunks fitting in L2
2334
2335 dim_t n_stats_ss_tensors = pd_->use_scale() + pd_->use_shift();
2336 dim_t size_stats_ss_tensors = n_stats_ss_tensors * get_c_padded(pd_)
2337 * sizeof(acc_data_t);
2338
2339 dim_t size_src_dst = 2 * N_ * S_ * get_c_padded(pd_)
2340 * types::data_type_size(pd_->src_md()->data_type);
2341
2342 dim_t total_size = size_src_dst + size_stats_ss_tensors;
2343
2344 // Try to create at least nthr_ chunks for realtime inference. Not
2345 // enabled for throughput inference to avoid potential regressions
2346 // for multi-socket runs with threadpool runtime.
2347 // TODO: Enable for throughput inference.
2348 const int n_chunks_min = nthr_ <= 8 ? nthr_ : 1;
2349 const size_t l2_per_core = platform::get_per_core_cache_size(2);
2350 dim_t n_chunks
2351 = nstl::max<dim_t>(n_chunks_min, total_size / l2_per_core);
2352
2353 // we prioritize parallelization on N, then S, and finally C
2354 nthr.N = utils::saturate<dim_t>(1, N_, n_chunks);
2355 nthr.S = utils::saturate<dim_t>(1, S_, n_chunks / nthr.N);
2356 nthr.C = utils::saturate<dim_t>(
2357 1, C_blks, n_chunks / (nthr.N * nthr.S));
2358 } else {
2359 if ((nthr_ <= C_blks && nthr_ == 1) || C_blks <= 8)
2360 nthr.C = 1;
2361 else if (nthr_ >= 8 && C_blks <= 32)
2362 nthr.C = 8;
2363 else {
2364 nthr.C = math::gcd((dim_t)nthr_, C_blks);
2365 // Unroll by channels in JIT kernel
2366 if ((nthr.C == C_blks) || (nthr.C == nthr_)) nthr.C = 1;
2367 }
2368 nthr.N = utils::saturate((dim_t)1, N_, nthr_ / nthr.C);
2369 nthr.S = utils::saturate((dim_t)1, S_, nthr_ / (nthr.C * nthr.N));
2370 }
2371 }
2372
2373 void thread_distribution(dim_t C_blks, bnorm_dims_t &nthr) {
2374 if (do_blocking_) {
2375 nthr.N = nstl::min<dim_t>(N_, nthr_);
2376 nthr.C = nstl::min<dim_t>(C_blks, nthr_ / nthr.N);
2377 nthr.S = utils::saturate((dim_t)1, S_, nthr_ / (nthr.C * nthr.N));
2378 } else {
2379 if (tag_kind_ == jit_memory_tag_kind_t::nspc) {
2380 thread_distribution_nspc(C_blks, nthr);
2381 } else {
2382 nthr.C = math::gcd((dim_t)nthr_, C_blks);
2383 nthr.N = utils::saturate((dim_t)1, N_, nthr_ / nthr.C);
2384 nthr.S = utils::saturate(
2385 (dim_t)1, S_, nthr_ / (nthr.C * nthr.N));
2386 }
2387 }
2388 nthr.glob = nthr.N * nthr.C * nthr.S;
2389 }
2390
2391 int map_thread_c(int ithr_glob, const bnorm_dims_t &nthr) {
2392 return ithr_glob / nthr.N / nthr.S;
2393 }
2394
2395 bnorm_dims_t map_thread(int ithr_glob, const bnorm_dims_t &nthr) {
2396 auto ithr = bnorm_dims_t();
2397 ithr.glob = ithr_glob;
2398 ithr.C = map_thread_c(ithr.glob, nthr);
2399 ithr.N = ithr.glob / nthr.S % nthr.N;
2400 ithr.S = ithr.glob % nthr.S;
2401 return ithr;
2402 }
2403
2404 void work_distribution_c(dim_t C_blks, int ithr_c, int nthr_c,
2405 dim_t &start_c, dim_t &stop_c) {
2406 balance211(C_blks, nthr_c, ithr_c, start_c, stop_c);
2407 }
2408
2409 void work_distribution(dim_t C_blks, const bnorm_dims_t &ithr,
2410 const bnorm_dims_t &nthr, bnorm_dims_t &start, bnorm_dims_t &stop) {
2411 work_distribution_c(C_blks, ithr.C, nthr.C, start.C, stop.C);
2412 balance211(N_, nthr.N, ithr.N, start.N, stop.N);
2413 balance211(S_, nthr.S, ithr.S, start.S, stop.S);
2414 }
2415
2416 const batch_normalization_pd_t *pd_;
2417 const jit_memory_tag_kind_t tag_kind_;
2418 const int simd_w;
2419
2420 bool do_blocking_;
2421
2422 int nthr_;
2423
2424 dim_t N_, S_; // MB, D * H *W
2425 dim_t C_, C_blks_; // C / simd_w
2426 dim_t C_blk_step_; // for C_blks = 0 .. C_blks_, += C_blk_step_
2427
2428 std::unique_ptr<jit_bnorm_fwd_t<isa>> ker_fwd_;
2429 std::unique_ptr<jit_bnorm_fwd_mean_t<isa>> ker_fwd_mean_;
2430 std::unique_ptr<jit_bnorm_fwd_var_t<isa>> ker_fwd_var_;
2431 std::unique_ptr<jit_bnorm_bwd_t<isa>> ker_bwd_;
2432 std::unique_ptr<jit_bnorm_bwd_diff_ss_t<isa>> ker_bwd_diff_ss_;
2433
2434 size_t dt_size_;
2435};
2436} // namespace bnorm_tbb_impl
2437
2438using namespace data_type;
2439using namespace format_tag;
2440using namespace utils;
2441
2442/* fwd */
2443template <cpu_isa_t isa>
2444status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::pd_t::init(
2445 engine_t *engine) {
2446 const bool ok = is_fwd() && mayiuse(isa) && !has_zero_dim_memory()
2447 && one_of(src_md()->data_type, f32, bf16, f16)
2448 && src_md()->data_type == dst_md()->data_type
2449 && IMPLICATION(src_md()->data_type == bf16,
2450 is_superset(isa, avx512_core)
2451 || (isa == avx2 && mayiuse(avx2_vnni_2)))
2452 // Note: re-using avx512_core/avx2 implementation for f16.
2453 // This is okay as currently, we do not support binary post-ops
2454 // for this primitive.
2455 && IMPLICATION(src_md()->data_type == f16,
2456 (is_superset(isa, avx512_core) && mayiuse(avx512_core_fp16))
2457 || (isa == avx2 && mayiuse(avx2_vnni_2)))
2458 && check_scale_shift_data_type()
2459 && (attr()->has_default_values()
2460 || with_relu_post_op(is_training()))
2461 && set_default_formats_common()
2462 && memory_desc_wrapper(src_md()) == memory_desc_wrapper(dst_md());
2463 if (!ok) return status::unimplemented;
2464
2465 // BN+Add+Relu fusion is not currently implemented
2466 if (fuse_norm_add_relu()) return status::unimplemented;
2467
2468 const format_tag_t blocked_tag = is_superset(isa, avx512_core)
2469 ? utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)
2470 : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
2471
2472 const format_tag_t blocked_format
2473 = memory_desc_matches_tag(*src_md(), blocked_tag)
2474 ? blocked_tag
2475 : format_tag::undef;
2476 const format_tag_t nspc_format
2477 = memory_desc_matches_one_of_tag(*src_md(), nc, nwc, nhwc, ndhwc);
2478
2479 if (memory_desc_matches_tag(*dst_md(), blocked_format))
2480 tag_kind_ = jit_memory_tag_kind_t::blocked;
2481 else if (memory_desc_matches_tag(*dst_md(), nspc_format)) {
2482 tag_kind_ = jit_memory_tag_kind_t::nspc;
2483 const int simd_w = get_simd_w<isa>(tag_kind_);
2484 if (C() % simd_w != 0) return status::unimplemented;
2485 } else
2486 return status::unimplemented;
2487
2488 // AVX2 only supports xf16 on plain layout and inference
2489 if (utils::one_of(src_md()->data_type, bf16, f16) && isa == avx2
2490 && (is_training()
2491 || !memory_desc_matches_tag(*dst_md(), nspc_format)))
2492 return status::unimplemented;
2493
2494 const bool isa_supports_avx2 = is_superset(isa, avx2);
2495 if (is_training() && fuse_norm_relu()) {
2496 if (!isa_supports_avx2) return status::unimplemented;
2497 init_default_ws(1);
2498 }
2499
2500 if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
2501 && !isa_supports_avx2)
2502 return status::unimplemented;
2503
2504 auto scratchpad = scratchpad_registry().registrar();
2505 bnorm_tbb_impl::driver_t<isa>::init_scratchpad(scratchpad, this);
2506
2507 return status::success;
2508}
2509
2510template <cpu_isa_t isa>
2511jit_uni_tbb_batch_normalization_fwd_t<
2512 isa>::jit_uni_tbb_batch_normalization_fwd_t(const pd_t *apd)
2513 : primitive_t(apd) {}
2514
2515template <cpu_isa_t isa>
2516status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::init(engine_t *engine) {
2517 CHECK(safe_ptr_assign(bnorm_driver_,
2518 new bnorm_tbb_impl::driver_t<isa>(pd(), pd()->tag_kind_)));
2519 return bnorm_driver_->create_kernel();
2520}
2521
2522template <cpu_isa_t isa>
2523status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::execute(
2524 const exec_ctx_t &ctx) const {
2525
2526 auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
2527 auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE);
2528 auto shift = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SHIFT);
2529
2530 auto mean = pd()->stats_is_src() ? const_cast<acc_data_t *>(
2531 CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN))
2532 : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_MEAN);
2533 auto var = pd()->stats_is_src()
2534 ? const_cast<acc_data_t *>(
2535 CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE))
2536 : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_VARIANCE);
2537
2538 auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
2539 auto ws = CTX_OUT_MEM(uint8_t *, DNNL_ARG_WORKSPACE);
2540
2541 auto scratchpad = ctx.get_scratchpad_grantor();
2542
2543 bnorm_driver_->exec_fwd(src, dst, scale, shift, mean, var, ws, scratchpad);
2544
2545 return status::success;
2546}
2547
2548template <cpu_isa_t isa>
2549jit_uni_tbb_batch_normalization_fwd_t<
2550 isa>::~jit_uni_tbb_batch_normalization_fwd_t()
2551 = default;
2552
2553template struct jit_uni_tbb_batch_normalization_fwd_t<sse41>;
2554template struct jit_uni_tbb_batch_normalization_fwd_t<avx2>;
2555template struct jit_uni_tbb_batch_normalization_fwd_t<avx512_core>;
2556
2557/* bwd */
2558template <cpu_isa_t isa>
2559status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::pd_t::init(
2560 engine_t *engine) {
2561 bool ok = !is_fwd() && mayiuse(isa) && !has_zero_dim_memory()
2562 && one_of(src_md()->data_type, f32, bf16, f16)
2563 && src_md()->data_type == diff_src_md()->data_type
2564 && diff_src_md()->data_type == diff_dst_md()->data_type
2565 && IMPLICATION(
2566 src_md()->data_type == bf16, is_superset(isa, avx512_core))
2567 // Note: re-using avx512_core implementation for f16. This is okay
2568 // as currently, we do not support binary post-ops for this
2569 // primitive.
2570 && IMPLICATION(src_md()->data_type == f16,
2571 is_superset(isa, avx512_core) && mayiuse(avx512_core_fp16))
2572 && check_scale_shift_data_type() && attr()->has_default_values()
2573 && set_default_formats_common()
2574 && memory_desc_wrapper(diff_src_md())
2575 == memory_desc_wrapper(diff_dst_md());
2576 if (!ok) return status::unimplemented;
2577
2578 // BN+Add+Relu fusion is not currently implemented
2579 if (fuse_norm_add_relu()) return status::unimplemented;
2580
2581 const format_tag_t blocked_tag = is_superset(isa, avx512_core)
2582 ? utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)
2583 : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
2584
2585 const format_tag_t blocked_format
2586 = memory_desc_matches_tag(*src_md(), blocked_tag)
2587 ? blocked_tag
2588 : format_tag::undef;
2589 const format_tag_t nspc_format
2590 = memory_desc_matches_one_of_tag(*src_md(), nc, nwc, nhwc, ndhwc);
2591
2592 if (memory_desc_matches_tag(*diff_src_md(), blocked_format))
2593 tag_kind_ = jit_memory_tag_kind_t::blocked;
2594 else if (memory_desc_matches_tag(*diff_src_md(), nspc_format)) {
2595 tag_kind_ = jit_memory_tag_kind_t::nspc;
2596 const int simd_w = get_simd_w<isa>(tag_kind_);
2597 if (C() % simd_w != 0) return status::unimplemented;
2598 } else
2599 return status::unimplemented;
2600
2601 const bool isa_supports_avx2 = is_superset(isa, avx2);
2602 if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
2603 && !isa_supports_avx2)
2604 return status::unimplemented;
2605
2606 if (fuse_norm_relu()) {
2607 if (!isa_supports_avx2) return status::unimplemented;
2608 init_default_ws(1);
2609 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
2610 }
2611
2612 auto scratchpad = scratchpad_registry().registrar();
2613 bnorm_tbb_impl::driver_t<isa>::init_scratchpad(scratchpad, this);
2614
2615 return status::success;
2616}
2617
2618template <cpu_isa_t isa>
2619jit_uni_tbb_batch_normalization_bwd_t<
2620 isa>::jit_uni_tbb_batch_normalization_bwd_t(const pd_t *apd)
2621 : primitive_t(apd) {}
2622
2623template <cpu_isa_t isa>
2624status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::init(engine_t *engine) {
2625 CHECK(safe_ptr_assign(bnorm_driver_,
2626 new bnorm_tbb_impl::driver_t<isa>(pd(), pd()->tag_kind_)));
2627 return bnorm_driver_->create_kernel();
2628}
2629
2630template <cpu_isa_t isa>
2631status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::execute(
2632 const exec_ctx_t &ctx) const {
2633
2634 auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
2635 auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN);
2636 auto var = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE);
2637 auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
2638 auto scale = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SCALE);
2639 auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE);
2640
2641 auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC);
2642 auto diff_scale = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SCALE);
2643 auto diff_shift = CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT);
2644
2645 auto scratchpad = ctx.get_scratchpad_grantor();
2646
2647 bnorm_driver_->exec_bwd(src, diff_src, diff_dst, scale, diff_scale,
2648 diff_shift, mean, var, ws, scratchpad);
2649
2650 return status::success;
2651}
2652
2653template <cpu_isa_t isa>
2654jit_uni_tbb_batch_normalization_bwd_t<
2655 isa>::~jit_uni_tbb_batch_normalization_bwd_t()
2656 = default;
2657
2658template struct jit_uni_tbb_batch_normalization_bwd_t<sse41>;
2659template struct jit_uni_tbb_batch_normalization_bwd_t<avx2>;
2660template struct jit_uni_tbb_batch_normalization_bwd_t<avx512_core>;
2661} // namespace x64
2662} // namespace cpu
2663} // namespace impl
2664} // namespace dnnl
2665