1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <math.h>
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/reorder.hpp"
23#include "common/type_helpers.hpp"
24
25#include "cpu/cpu_primitive.hpp"
26
27#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
28#include "cpu/x64/jit_generator.hpp"
29#include "cpu/x64/jit_uni_layer_normalization.hpp"
30#include "cpu/x64/utils/jit_io_helper.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35namespace x64 {
36
37using namespace memory_tracking::names;
38using namespace data_type;
39using namespace Xbyak;
40
41cpu_isa_t get_io_isa(cpu_isa_t isa, bool has_f16) {
    // f16 reuses the avx512_core kernel instantiation; only the IO helper
    // ISA is promoted to avx512_core_fp16.
43 return has_f16 && is_superset(isa, avx512_core) ? avx512_core_fp16 : isa;
44}
45
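// Forward kernel summary (describes the code below): for every row x[0..C)
// of the block it computes (or loads) the statistics
//     mean = sum(x) / C
//     var  = sum((x - mean)^2) / C
// and writes
//     dst[c] = ((x[c] - mean) / sqrt(var + eps) * scale[c] + shift[c])
//              * combined_src_dst_scale
// converted (with saturation) to the destination data type.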
46template <cpu_isa_t isa>
47struct jit_stat_and_data_base_kernel_t : stat_and_data_kernel_t,
48 public jit_generator {
49 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_lnorm_stat_and_data_kernel_t);
50
51 void operator()(const void *src, void *dst, const float *scale,
52 const float *shift, float *mean, float *var,
53 const float *src_scales, const float *dst_scales,
54 const size_t block_size) const override {
55 ker_args_t args;
56 args.src = src;
57 args.dst = dst;
58 args.scale = scale;
59 args.shift = shift;
60 args.mean = mean;
61 args.var = var;
62 args.src_scales = src_scales;
63 args.dst_scales = dst_scales;
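        // `block_size` counts rows; the kernel expects the block length in
        // bytes over the whole normalization axis, so scale by C and the
        // src data type size.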
64 args.block_size
65 = block_size * C_ * types::data_type_size(src_d_.data_type());
66 args.eps = eps_;
67 jit_generator::operator()(&args);
68 }
69
70 status_t create_kernel() override { return jit_generator::create_kernel(); }
71
72 jit_stat_and_data_base_kernel_t(const layer_normalization_pd_t *pd)
73 : stat_and_data_kernel_t(pd)
74 , jit_generator(jit_name())
75 , src_d_(pd_->src_md())
76 , dst_d_(pd_->dst_md())
77 , simd_w_(vlen / sizeof(float))
78 , C_(pd_->norm_axis())
79 , axis_simd_full_(C_ / simd_w_)
80 , axis_simd_tail_(C_ % simd_w_)
81 , use_scale_(pd_->use_scale())
82 , use_shift_(pd_->use_shift())
83 , save_stats_(pd_->is_training())
84 , calculate_stats_(!pd_->stats_are_src())
85 , eps_(pd_->desc()->layer_norm_epsilon) {
86
87 io::io_conf_t io_conf;
88 io::io_tail_conf_t io_tail_conf(simd_w_, axis_simd_tail_,
89 tail_opmask_idx, vmm_tail_mask.getIdx(), reg_tmp);
90 io::io_emu_bf16_conf_t io_bf16_conf(bf16_emu_zmm_1_idx,
91 bf16_emu_zmm_2_idx, bf16_emu_zmm_3_idx, reg_tmp,
92 bf16_emu_zmm_4_idx);
93 io::io_saturation_conf_t io_saturation_conf(
94 vmm_zero.getIdx(), vmm_saturation_ubound.getIdx(), reg_tmp);
95 const auto io_isa = get_io_isa(isa,
96 utils::one_of(f16, src_d_.data_type(), dst_d_.data_type()));
97 io_ = io::jit_io_multi_dt_helper_t<Vmm>(this, io_isa,
98 {src_d_.data_type(), dst_d_.data_type(), f32 /* stats */},
99 io_conf, io_tail_conf, io_bf16_conf,
100 {{dst_d_.data_type(), io_saturation_conf}});
101 }
102
103protected:
104 static constexpr int unroll_factor_ = 4;
105 using Vmm = typename cpu_isa_traits<isa>::Vmm;
106 const AddressFrame &vmmword
107 = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
108 const int vlen = cpu_isa_traits<isa>::vlen;
109
110 struct ker_args_t {
111 const void *src;
112 void *dst;
113 const float *scale;
114 const float *shift;
115 const float *mean;
116 const float *var;
117 const float *src_scales;
118 const float *dst_scales;
119 size_t block_size;
120 float eps;
121 };
122
123 io::jit_io_multi_dt_helper_t<Vmm> io_;
124 const memory_desc_wrapper src_d_, dst_d_;
125 const size_t simd_w_;
126 const dim_t C_;
127 const dim_t axis_simd_full_;
128 const dim_t axis_simd_tail_;
129 const bool use_scale_;
130 const bool use_shift_;
131 const bool save_stats_;
132 const bool calculate_stats_;
133 const float eps_;
134
135 const Reg64 reg_param = abi_param1;
136 const Reg64 reg_src = rdx;
137 const Reg64 reg_dst = rax;
138 const Reg64 reg_mean = rbx;
139 const Reg64 reg_scale = r8;
140 const Reg64 reg_block_end = r9;
141 const Reg64 reg_eps = r10;
142 const Reg64 reg_tmp = r11;
143 const Reg64 reg_shift = r12;
144 const Reg64 reg_var = r13;
145 const Reg64 reg_src_scales = r14;
146 const Reg64 reg_dst_scales = r15;
147
148 const Vmm vmm_tail_mask = Vmm(0);
149 const Vmm vmm_zero = Vmm(4); // In unroll range, safe for dst compute.
150 const Vmm vmm_saturation_ubound
151 = Vmm(5); // In unroll range, safe for dst compute.
152 const Vmm vmm_combined_scales
153 = Vmm(6); // In unroll range, safe for dst compute.
154 const Vmm vmm_scale = Vmm(7); // In unroll range, safe for dst compute.
155 const Vmm vmm_shift = Vmm(8); // In unroll range, safe for dst compute.
156 const Vmm vmm_ones = Vmm(9);
157 const Vmm vmm_eps = Vmm(10);
158 const Vmm vmm_c = Vmm(11);
159 const Vmm vmm_mean = Vmm(12);
160 const Vmm vmm_inv_sqrtvar = Vmm(13);
161 const Vmm vmm_dst = Vmm(14);
162 const Vmm vmm_tmp = Vmm(15);
163 const Xmm xmm_tmp = Xmm(15);
164
165 const int bf16_emu_zmm_1_idx = 28;
166 const int bf16_emu_zmm_2_idx = 29;
167 const int bf16_emu_zmm_3_idx = 30;
168 const int bf16_emu_zmm_4_idx = 31;
169 const int tail_opmask_idx = 1;
170
171 Address src_ptr(size_t offt = 0) {
172 return vmmword[reg_src + offt * src_d_.data_type_size()];
173 }
174
175 Address dst_ptr(size_t offt = 0) {
176 return vmmword[reg_dst + offt * dst_d_.data_type_size()];
177 }
178
179 Address mean_ptr(size_t offt = 0) {
180 return vmmword[reg_mean + offt * sizeof(float)];
181 }
182
183 Address var_ptr(size_t offt = 0) {
184 return vmmword[reg_var + offt * sizeof(float)];
185 }
186
187 Address scale_ptr(size_t offt = 0) {
188 return vmmword[reg_scale + offt * sizeof(float)];
189 }
190
191 Address shift_ptr(size_t offt = 0) {
192 return vmmword[reg_shift + offt * sizeof(float)];
193 }
194
195 virtual void compute_var() = 0;
196 virtual void reduce(Vmm vmm_src, Vmm vmm_tmp) = 0;
197
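    // Computes one per-row statistic over the normalization axis. `op`
    // accumulates per-lane partial sums (plain sums for the mean, sums of
    // squared centered values for the variance) into up to `unroll_factor_`
    // accumulators; the accumulators are folded pairwise, `reduce()` turns
    // the surviving register into a horizontal sum present in every lane,
    // and the result is divided by C and copied into `vmm_stat`.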
198 template <typename F>
199 void compute(Vmm vmm_stat, F op) {
200 bool need_tail = false;
        int base_idx = 1; // Keep Vmm(0) intact: it holds the tail mask.
202
203 uni_vpxor(Vmm(base_idx), Vmm(base_idx), Vmm(base_idx));
204 if (axis_simd_full_ > 0) {
205 const int unroll
206 = axis_simd_full_ >= unroll_factor_ ? unroll_factor_ : 1;
207 assert(math::is_pow2(unroll));
208
209 for (int i = base_idx + 1; i < base_idx + unroll; i++)
210 uni_vpxor(Vmm(i), Vmm(i), Vmm(i));
211
212 // unrolled loop
213 for (int i = 0; i < axis_simd_full_ / unroll; i++)
214 for (int j = base_idx; j < base_idx + unroll; j++) {
215 io_[src_d_.data_type()]->load(
216 src_ptr((i * unroll + j - base_idx) * simd_w_),
217 Vmm(j + unroll), need_tail);
218 op(Vmm(j), Vmm(j + unroll), need_tail);
219 }
220
221 // unrolled loop reduction
222 int n = unroll;
223 while (n > 1) {
224 for (int j = base_idx; j < base_idx + n / 2; j++)
225 uni_vaddps(Vmm(j), Vmm(j), Vmm(j + n / 2));
226 n = n / 2;
227 }
228
229 // unrolled loop remainder
230 for (int i = utils::rnd_dn(axis_simd_full_, unroll);
231 i < axis_simd_full_; i++) {
232 io_[src_d_.data_type()]->load(
233 src_ptr(i * simd_w_), Vmm(base_idx + 1), need_tail);
234 op(Vmm(base_idx), Vmm(base_idx + 1), need_tail);
235 }
236 }
237
238 if (axis_simd_tail_ > 0) {
239 need_tail = true;
240 // vector remainder
241 io_[src_d_.data_type()]->load(src_ptr(axis_simd_full_ * simd_w_),
242 Vmm(base_idx + 1), need_tail);
243 op(Vmm(base_idx), Vmm(base_idx + 1), need_tail);
244 }
245
246 reduce(Vmm(base_idx), Vmm(base_idx + 1));
247 uni_vdivps(Vmm(base_idx), Vmm(base_idx), vmm_c, vmm_tmp);
248 uni_vmovups(vmm_stat, Vmm(base_idx));
249 }
250
251 void compute_mean() {
252 compute(vmm_mean, [&](Vmm vmm_dst, Vmm vmm_src, bool need_tail) {
253 uni_vaddps(vmm_dst, vmm_dst, vmm_src);
254 });
255 if (save_stats_) uni_vmovss(ptr[reg_mean], Xmm(vmm_mean.getIdx()));
256 }
257
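    // Normalizes one SIMD block of the row:
    //     dst = (src - mean) * inv_sqrtvar [* scale] [+ shift]
    // then multiplies by the combined src/dst scales and stores with
    // conversion (and saturation) to the destination data type.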
258 void calculate_dst(size_t offt_elems, bool tail = false) {
259 if (use_scale_) {
260 io_[f32]->load(scale_ptr(offt_elems), vmm_scale, tail);
261 }
262 if (use_shift_) {
263 io_[f32]->load(shift_ptr(offt_elems), vmm_shift, tail);
264 }
265 io_[src_d_.data_type()]->load(src_ptr(offt_elems), vmm_dst, tail);
266 uni_vsubps(vmm_dst, vmm_dst, vmm_mean);
267 uni_vmulps(vmm_dst, vmm_dst, vmm_inv_sqrtvar);
268 if (use_scale_ && use_shift_)
269 uni_vfmadd213ps(vmm_dst, vmm_scale, vmm_shift);
270 else {
271 if (use_scale_) uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
272 if (use_shift_) uni_vaddps(vmm_dst, vmm_dst, vmm_shift);
273 }
274 uni_vmulps(vmm_dst, vmm_dst, vmm_combined_scales);
275 io_[dst_d_.data_type()]->store(vmm_dst, dst_ptr(offt_elems), tail);
276 }
277
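    // The generated code loops over the rows of the current block: for each
    // row it either computes mean/variance or loads the user-provided stats,
    // derives inv_sqrtvar, applies the runtime scales, writes the normalized
    // row, and then advances src/dst by one row and mean/var by one float.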
278 void generate() override {
279 const size_t c_src_size
280 = C_ * types::data_type_size(src_d_.data_type());
281 const size_t c_dst_size
282 = C_ * types::data_type_size(dst_d_.data_type());
283 static const size_t float_size = types::data_type_size(f32);
284
285 preamble();
286
287 io_.init_bf16();
288 if (axis_simd_tail_) io_.prepare_tail_mask();
289
290#define PARAM_OFF(x) offsetof(ker_args_t, x)
291 mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
292 mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
293 mov(reg_scale, ptr[reg_param + PARAM_OFF(scale)]);
294 mov(reg_shift, ptr[reg_param + PARAM_OFF(shift)]);
295 mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
296 mov(reg_var, ptr[reg_param + PARAM_OFF(var)]);
297 mov(reg_src_scales, ptr[reg_param + PARAM_OFF(src_scales)]);
298 mov(reg_dst_scales, ptr[reg_param + PARAM_OFF(dst_scales)]);
299 mov(reg_block_end, ptr[reg_param + PARAM_OFF(block_size)]);
300 mov(reg_eps, ptr[reg_param + PARAM_OFF(eps)]);
301#undef PARAM_OFF
302
303 uni_vmovq(xmm_tmp, reg_eps);
304 uni_vbroadcastss(vmm_eps, xmm_tmp);
305 mov(reg_tmp, float2int(1.f));
306 uni_vmovq(xmm_tmp, reg_tmp);
307 uni_vbroadcastss(vmm_ones, xmm_tmp);
308 mov(reg_tmp, float2int(C_));
309 uni_vmovq(xmm_tmp, reg_tmp);
310 uni_vbroadcastss(vmm_c, xmm_tmp);
311
        // reg_block_end holds block_size in bytes; adding the src base
        // pointer turns it into the end address of the block.
313 add(reg_block_end, reg_src);
314
315 Label unroll_loop, end;
316 L(unroll_loop);
317 {
318 cmp(reg_block_end, reg_src);
319 jle(end, T_NEAR);
320
321 if (calculate_stats_) {
322 // compute stats
323 compute_mean();
324 compute_var();
325 } else {
                // load user-provided mean and variance (the variance is
                // converted to inv_sqrtvar just below)
327 uni_vmovss(xmm_tmp, dword[reg_mean]);
328 uni_vbroadcastss(vmm_mean, xmm_tmp);
329 uni_vmovss(xmm_tmp, dword[reg_var]);
330 uni_vbroadcastss(vmm_inv_sqrtvar, xmm_tmp);
331 }
332
            // inv_sqrtvar = 1.0f / sqrt(variance + eps)
334 uni_vaddps(vmm_inv_sqrtvar, vmm_inv_sqrtvar, vmm_eps);
335 uni_vsqrtps(vmm_inv_sqrtvar, vmm_inv_sqrtvar);
336 uni_vdivps(vmm_inv_sqrtvar, vmm_ones, vmm_inv_sqrtvar, vmm_tmp);
337
            // Precompute and broadcast the combined src * dst scales; they
            // are runtime values, so read them from the kernel arguments.
339 uni_vmovss(xmm_tmp, dword[reg_src_scales]);
340 uni_vbroadcastss(vmm_combined_scales, xmm_tmp);
341 uni_vmovss(xmm_tmp, dword[reg_dst_scales]);
342 uni_vbroadcastss(vmm_tmp, xmm_tmp);
343 uni_vmulps(vmm_combined_scales, vmm_combined_scales, vmm_tmp);
344 io_.init_saturate_f32({dst_d_.data_type()});
345
346 // calculate dst
347 for (int i = 0; i < axis_simd_full_; i++)
348 calculate_dst(i * simd_w_);
349 if (axis_simd_tail_) calculate_dst(axis_simd_full_ * simd_w_, true);
350
351 add(reg_src, c_src_size);
352 add(reg_dst, c_dst_size);
353 add(reg_mean, float_size);
354 add(reg_var, float_size);
355 jmp(unroll_loop);
356 }
357 L(end);
358
359 postamble();
360 }
361};
362
363template <cpu_isa_t isa>
364struct jit_stat_and_data_kernel_t;
365
366template <>
367struct jit_stat_and_data_kernel_t<avx512_core>
368 : public jit_stat_and_data_base_kernel_t<avx512_core> {
369
370 using jit_stat_and_data_base_kernel_t::jit_stat_and_data_base_kernel_t;
371
372 void compute_var() override {
373 compute(vmm_inv_sqrtvar, [&](Vmm vmm_dst, Vmm vmm_src, bool need_tail) {
            // Tail lanes must stay zero after the subtraction; otherwise the
            // masked-out lanes would add (-mean)^2 to the variance sum.
375 if (!need_tail)
376 uni_vsubps(vmm_src, vmm_src, vmm_mean);
377 else
378 uni_vsubps(vmm_src | Opmask(tail_opmask_idx) | T_z, vmm_src,
379 vmm_mean);
380 uni_vfmadd231ps(vmm_dst, vmm_src, vmm_src);
381 });
382 if (save_stats_)
383 uni_vmovss(ptr[reg_var], Xmm(vmm_inv_sqrtvar.getIdx()));
384 }
385
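    // Horizontal sum of a zmm: each swap-and-add halves the number of
    // distinct partial sums, so after the last step every lane holds the
    // total and compute() can divide by C without a scalar extract.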
386 void reduce(Vmm vmm_src, Vmm vmm_tmp) override {
        vshuff32x4(vmm_tmp, vmm_src, vmm_src, 0x4E); // swap 256-bit halves
        vaddps(vmm_src, vmm_src, vmm_tmp);
        vshuff32x4(vmm_tmp, vmm_src, vmm_src, 0xB1); // swap 128-bit lanes in each half
        vaddps(vmm_src, vmm_src, vmm_tmp);
        vshufps(vmm_tmp, vmm_src, vmm_src, 0x4E); // swap 64-bit pairs in each lane
        vaddps(vmm_src, vmm_src, vmm_tmp);
        vshufps(vmm_tmp, vmm_src, vmm_src, 0xB1); // swap adjacent 32-bit elements
394 vaddps(vmm_src, vmm_src, vmm_tmp);
395 }
396};
397
398template <>
399struct jit_stat_and_data_kernel_t<avx2>
400 : jit_stat_and_data_base_kernel_t<avx2> {
401
402 using jit_stat_and_data_base_kernel_t::jit_stat_and_data_base_kernel_t;
403
404 void compute_var() override {
405 compute(vmm_inv_sqrtvar, [&](Vmm vmm_dst, Vmm vmm_src, bool need_tail) {
            // Tail lanes must stay zero after the subtraction; otherwise the
            // masked-out lanes would add (-mean)^2 to the variance sum.
407 if (!need_tail)
408 uni_vsubps(vmm_src, vmm_src, vmm_mean);
409 else {
                // The tail path runs only once per row, so clobbering
                // `vmm_tmp` here is fine.
411 uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp);
412 uni_vblendvps(vmm_tmp, vmm_tmp, vmm_mean, vmm_tail_mask);
413 uni_vsubps(vmm_src, vmm_src, vmm_tmp);
414 }
415 uni_vfmadd231ps(vmm_dst, vmm_src, vmm_src);
416 });
417 if (save_stats_)
418 uni_vmovss(ptr[reg_var], Xmm(vmm_inv_sqrtvar.getIdx()));
419 }
420
421 void reduce(Vmm vmm_src, Vmm vmm_tmp) override {
        vperm2f128(vmm_tmp, vmm_src, vmm_src, 0x1); // swap 128-bit halves
        vaddps(vmm_src, vmm_src, vmm_tmp);
        vshufps(vmm_tmp, vmm_src, vmm_src, 0x4E); // swap 64-bit pairs in each half
        vaddps(vmm_src, vmm_src, vmm_tmp);
        vshufps(vmm_tmp, vmm_src, vmm_src, 0xB1); // swap adjacent 32-bit elements
427 vaddps(vmm_src, vmm_src, vmm_tmp);
428 }
429};
430
431template <>
432struct jit_stat_and_data_kernel_t<sse41>
433 : jit_stat_and_data_base_kernel_t<sse41> {
434
435 using jit_stat_and_data_base_kernel_t::jit_stat_and_data_base_kernel_t;
436
437 void compute_var() override {
438 compute(vmm_inv_sqrtvar, [&](Vmm vmm_dst, Vmm vmm_src, bool need_tail) {
            // Tail lanes must stay zero after the subtraction; otherwise the
            // masked-out lanes would add (-mean)^2 to the variance sum.
440 if (!need_tail)
441 uni_vsubps(vmm_src, vmm_src, vmm_mean);
442 else {
                // The tail path runs only once per row, so clobbering
                // `vmm_tmp` here is fine.
444 uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp);
445 uni_vblendvps(vmm_tmp, vmm_tmp, vmm_mean, vmm_tail_mask);
446 uni_vsubps(vmm_src, vmm_src, vmm_tmp);
447 }
448 uni_vfmadd231ps(vmm_dst, vmm_src, vmm_src);
449 });
450 if (save_stats_)
451 uni_vmovss(ptr[reg_var], Xmm(vmm_inv_sqrtvar.getIdx()));
452 }
453
454 void reduce(Vmm vmm_src, Vmm vmm_tmp) override {
455 uni_vmovups(vmm_tmp, vmm_src);
        shufps(vmm_tmp, vmm_tmp, 0x4E); // swap 64-bit halves
        uni_vaddps(vmm_src, vmm_src, vmm_tmp);
        uni_vmovups(vmm_tmp, vmm_src);
        shufps(vmm_tmp, vmm_tmp, 0xB1); // swap adjacent 32-bit elements
460 uni_vaddps(vmm_src, vmm_src, vmm_tmp);
461 }
462};
463
464stat_and_data_kernel_t *stat_and_data_kernel_t::create(
465 const layer_normalization_pd_t *pd) {
466 if (mayiuse(avx512_core)) {
467 return new jit_stat_and_data_kernel_t<avx512_core>(pd);
468 } else if (mayiuse(avx2)) {
469 return new jit_stat_and_data_kernel_t<avx2>(pd);
470 } else if (mayiuse(sse41)) {
471 return new jit_stat_and_data_kernel_t<sse41>(pd);
472 } else {
473 assert(!"kernel is empty.");
474 return nullptr;
475 }
476}
477
478template <cpu_isa_t isa>
479struct jit_diff_ss_kernel_t : diff_ss_kernel_t, public jit_generator {
480 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_lnorm_diff_ss_kernel_t);
481
482 void operator()(const void *src, const void *diff_dst, float *diff_scale,
483 float *diff_shift, const float *mean, const float *var,
484 float *const inv_sqrtvar, const size_t block_size) const override {
485 ker_args_t args;
486 args.src = src;
487 args.diff_dst = diff_dst;
488 args.diff_scale = diff_scale;
489 args.diff_shift = diff_shift;
490 args.mean = mean;
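        // Precompute inv_sqrtvar = 1 / sqrt(var + eps) for every row of the
        // block on the host side; the JIT kernel only reads the result.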
491 for (size_t i = 0; i < block_size; i++) {
492#ifdef __INTEL_COMPILER
            // Without `volatile`, ICC at -O2/-O3 optimizes the denominator
            // out of inv_sqrtvar and computes 1/denom with lower precision.
495 const volatile float denom = sqrtf(var[i] + eps_);
496#else
497 const float denom = sqrtf(var[i] + eps_);
498#endif
499 inv_sqrtvar[i] = 1.f / denom;
500 }
501 args.inv_sqrtvar = inv_sqrtvar;
502 args.block_size
503 = block_size * C_ * types::data_type_size(src_d_.data_type());
504 jit_generator::operator()(&args);
505 }
506
507 status_t create_kernel() override { return jit_generator::create_kernel(); }
508
509 jit_diff_ss_kernel_t(const layer_normalization_pd_t *pd)
510 : diff_ss_kernel_t(pd)
511 , jit_generator(jit_name())
512 , src_d_(pd_->src_md())
513 , d_dst_d_(pd_->diff_dst_md())
514 , simd_w_(vlen / sizeof(float))
515 , C_(pd_->norm_axis())
516 , axis_simd_full_(C_ / simd_w_)
517 , axis_simd_tail_(C_ % simd_w_)
518 , eps_(pd_->desc()->layer_norm_epsilon) {
519
520 io::io_conf_t io_conf;
521 io::io_tail_conf_t io_tail_conf(simd_w_, axis_simd_tail_,
522 tail_opmask_idx, vmm_tail_mask.getIdx(), reg_tmp);
523 io::io_emu_bf16_conf_t io_bf16_conf(bf16_emu_zmm_1_idx,
524 bf16_emu_zmm_2_idx, bf16_emu_zmm_3_idx, reg_tmp,
525 bf16_emu_zmm_4_idx);
526 const auto io_isa = get_io_isa(isa,
527 utils::one_of(f16, src_d_.data_type(), d_dst_d_.data_type()));
528 io_ = io::jit_io_multi_dt_helper_t<Vmm>(this, io_isa,
529 {src_d_.data_type(), d_dst_d_.data_type(), f32 /* stats */},
530 io_conf, io_tail_conf, io_bf16_conf);
531 }
532
533protected:
534 using Vmm = typename cpu_isa_traits<isa>::Vmm;
535 const AddressFrame &vmmword
536 = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
537 const int vlen = cpu_isa_traits<isa>::vlen;
538
539 struct ker_args_t {
540 const void *src;
541 const void *diff_dst;
542 float *diff_scale;
543 float *diff_shift;
544 const float *mean;
545 const float *inv_sqrtvar;
546 size_t block_size;
547 };
548
549 io::jit_io_multi_dt_helper_t<Vmm> io_;
550 const memory_desc_wrapper src_d_, d_dst_d_;
551 const size_t simd_w_;
552 const dim_t C_;
553 const dim_t axis_simd_full_;
554 const dim_t axis_simd_tail_;
555 const float eps_;
556
557 const Reg64 reg_param = abi_param1;
558 const Reg64 reg_src = rdx;
559 const Reg64 reg_diff_dst = rax;
560 const Reg64 reg_mean = rbx;
561 const Reg64 reg_diff_scale = r8;
562 const Reg64 reg_block_end = r9;
563 const Reg64 reg_tmp = r11;
564 const Reg64 reg_diff_shift = r12;
565 const Reg64 reg_inv_sqrtvar = r13;
566
567 const Vmm vmm_tail_mask = Vmm(0);
568 const Xmm xmm_tmp = Xmm(9);
569 const Vmm vmm_inv_sqrtvar = Vmm(10);
570 const Vmm vmm_ddst = Vmm(11);
571 const Vmm vmm_dscale = Vmm(12);
572 const Vmm vmm_dshift = Vmm(13);
573 const Vmm vmm_src = Vmm(14);
574 const Vmm vmm_mean = Vmm(15);
575
576 const int bf16_emu_zmm_1_idx = 28;
577 const int bf16_emu_zmm_2_idx = 29;
578 const int bf16_emu_zmm_3_idx = 30;
579 const int bf16_emu_zmm_4_idx = 31;
580 const int tail_opmask_idx = 1;
581
582 Address src_ptr(size_t offt = 0) {
583 return vmmword[reg_src + offt * src_d_.data_type_size()];
584 }
585
586 Address d_dst_ptr(size_t offt = 0) {
587 return vmmword[reg_diff_dst + offt * d_dst_d_.data_type_size()];
588 }
589
590 Address d_scale_ptr(size_t offt = 0) {
591 return vmmword[reg_diff_scale + offt * sizeof(float)];
592 }
593
594 Address d_shift_ptr(size_t offt = 0) {
595 return vmmword[reg_diff_shift + offt * sizeof(float)];
596 }
597
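    // Accumulates per-channel weight gradients for one SIMD block of a row:
    //     diff_shift[c] += diff_dst[c]
    //     diff_scale[c] += diff_dst[c] * (src[c] - mean) * inv_sqrtvar
    // The partial sums live in the diff_scale/diff_shift buffers and are
    // re-loaded and re-stored for every row of the block.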
598 void calculate_diff_scale_shift(size_t offt_elems, bool tail = false) {
599 io_[d_dst_d_.data_type()]->load(d_dst_ptr(offt_elems), vmm_ddst, tail);
600 io_[f32]->load(d_scale_ptr(offt_elems), vmm_dscale, tail);
601 io_[f32]->load(d_shift_ptr(offt_elems), vmm_dshift, tail);
602 io_[src_d_.data_type()]->load(src_ptr(offt_elems), vmm_src, tail);
603
604 uni_vaddps(vmm_dshift, vmm_dshift, vmm_ddst);
605 uni_vsubps(vmm_src, vmm_src, vmm_mean);
606 uni_vmulps(vmm_src, vmm_src, vmm_inv_sqrtvar);
607 uni_vfmadd231ps(vmm_dscale, vmm_src, vmm_ddst);
608
609 io_[f32]->store(vmm_dscale, d_scale_ptr(offt_elems), tail);
610 io_[f32]->store(vmm_dshift, d_shift_ptr(offt_elems), tail);
    }
612
613 void generate() override {
614 const size_t c_src_size
615 = C_ * types::data_type_size(src_d_.data_type());
616 const size_t c_ddst_size
617 = C_ * types::data_type_size(d_dst_d_.data_type());
618 static const size_t float_size = types::data_type_size(f32);
619
620 preamble();
621
622 io_.init_bf16();
623 if (axis_simd_tail_) io_.prepare_tail_mask();
624
625#define PARAM_OFF(x) offsetof(ker_args_t, x)
626 mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
627 mov(reg_diff_dst, ptr[reg_param + PARAM_OFF(diff_dst)]);
628 mov(reg_diff_scale, ptr[reg_param + PARAM_OFF(diff_scale)]);
629 mov(reg_diff_shift, ptr[reg_param + PARAM_OFF(diff_shift)]);
630 mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
631 mov(reg_inv_sqrtvar, ptr[reg_param + PARAM_OFF(inv_sqrtvar)]);
632 mov(reg_block_end, ptr[reg_param + PARAM_OFF(block_size)]);
633#undef PARAM_OFF
634
        // reg_block_end holds block_size in bytes; adding the src base
        // pointer turns it into the end address of the block.
636 add(reg_block_end, reg_src);
637
638 Label unroll_loop, end;
639 L(unroll_loop);
640 {
641 cmp(reg_block_end, reg_src);
642 jle(end, T_NEAR);
643
644 uni_vmovss(xmm_tmp, dword[reg_mean]);
645 uni_vbroadcastss(vmm_mean, xmm_tmp);
646 uni_vmovss(xmm_tmp, dword[reg_inv_sqrtvar]);
647 uni_vbroadcastss(vmm_inv_sqrtvar, xmm_tmp);
648
649 for (int i = 0; i < axis_simd_full_; i++)
650 calculate_diff_scale_shift(i * simd_w_);
651 if (axis_simd_tail_)
652 calculate_diff_scale_shift(axis_simd_full_ * simd_w_, true);
653
654 add(reg_src, c_src_size);
655 add(reg_diff_dst, c_ddst_size);
656 add(reg_mean, float_size);
657 add(reg_inv_sqrtvar, float_size);
658 jmp(unroll_loop);
659 }
660 L(end);
661
662 postamble();
663 }
664};
665
666diff_ss_kernel_t *diff_ss_kernel_t::create(const layer_normalization_pd_t *pd) {
667 if (mayiuse(avx512_core)) {
668 return new jit_diff_ss_kernel_t<avx512_core>(pd);
669 } else if (mayiuse(avx2)) {
670 return new jit_diff_ss_kernel_t<avx2>(pd);
671 } else {
672 assert(!"kernel is empty.");
673 return nullptr;
674 }
675}
676
677template <cpu_isa_t isa>
678struct jit_diff_data_base_kernel_t : diff_data_kernel_t, public jit_generator {
679 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_lnorm_diff_data_kernel_t);
680
681 void operator()(const void *src, const void *diff_dst, void *diff_src,
682 const float *ss, const float *mean, float *const inv_sqrtvar,
683 const size_t block_size) const override {
684 ker_args_t args;
685 args.src = src;
686 args.diff_dst = diff_dst;
687 args.diff_src = diff_src;
688 args.ss = ss;
689 args.mean = mean;
690 args.inv_sqrtvar = inv_sqrtvar;
691 args.block_size
692 = block_size * C_ * types::data_type_size(src_d_.data_type());
693 jit_generator::operator()(&args);
694 }
695
696 status_t create_kernel() override { return jit_generator::create_kernel(); }
697
698 jit_diff_data_base_kernel_t(const layer_normalization_pd_t *pd)
699 : diff_data_kernel_t(pd)
700 , jit_generator(jit_name())
701 , src_d_(pd_->src_md())
702 , d_dst_d_(pd_->diff_dst_md())
703 , d_src_d_(pd_->diff_src_md())
704 , simd_w_(vlen / sizeof(float))
705 , C_(pd_->norm_axis())
706 , axis_simd_full_(C_ / simd_w_)
707 , axis_simd_tail_(C_ % simd_w_)
708 , use_scale_(pd_->use_scale())
709 , use_shift_(pd_->use_shift())
710 , calculate_diff_stats_(!pd_->stats_are_src()) {
711
712 io::io_conf_t io_conf;
713 io::io_tail_conf_t io_tail_conf(simd_w_, axis_simd_tail_,
714 tail_opmask_idx, vmm_tail_mask.getIdx(), reg_tmp);
715 io::io_emu_bf16_conf_t io_bf16_conf(bf16_emu_zmm_1_idx,
716 bf16_emu_zmm_2_idx, bf16_emu_zmm_3_idx, reg_tmp,
717 bf16_emu_zmm_4_idx);
718 const auto io_isa = get_io_isa(isa,
719 utils::one_of(f16, src_d_.data_type(), d_dst_d_.data_type(),
720 d_src_d_.data_type()));
721 io_ = io::jit_io_multi_dt_helper_t<Vmm>(this, io_isa,
722 {src_d_.data_type(), d_dst_d_.data_type(), d_src_d_.data_type(),
723 f32 /* stats */},
724 io_conf, io_tail_conf, io_bf16_conf);
725 }
726
727protected:
728 static constexpr int unroll_factor_ = 4;
729 using Vmm = typename cpu_isa_traits<isa>::Vmm;
730 const AddressFrame &vmmword
731 = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
732 const int vlen = cpu_isa_traits<isa>::vlen;
733
734 struct ker_args_t {
735 const void *src;
736 const void *diff_dst;
737 void *diff_src;
738 const float *ss;
739 const float *mean;
740 const float *inv_sqrtvar;
741 size_t block_size;
742 };
743
744 io::jit_io_multi_dt_helper_t<Vmm> io_;
745 const memory_desc_wrapper src_d_, d_dst_d_, d_src_d_;
746 const size_t simd_w_;
747 const dim_t C_;
748 const dim_t axis_simd_full_;
749 const dim_t axis_simd_tail_;
750 const bool use_scale_;
751 const bool use_shift_;
752 const bool calculate_diff_stats_;
753
754 const Reg64 reg_param = abi_param1;
755 const Reg64 reg_src = rdx;
756 const Reg64 reg_diff_dst = rax;
757 const Reg64 reg_diff_src = r14;
758 const Reg64 reg_mean = rbx;
759 const Reg64 reg_inv_sqrtvar = r13;
760 const Reg64 reg_scale = r8;
761 const Reg64 reg_tmp = r11;
762 const Reg64 reg_dd_scale = r10;
763 const Reg64 reg_dd_scale_x = r12;
764 const Reg64 reg_block_end = r9;
765
766 const Vmm vmm_tail_mask = Vmm(0);
767 const Vmm vmm_C = Vmm(7);
768 const Vmm vmm_scale = Vmm(8);
769 const Xmm xmm_tmp = Xmm(9);
770 const Vmm vmm_tmp = Vmm(9);
771 const Vmm vmm_inv_sqrtvar = Vmm(10);
772 const Vmm vmm_dsrc = Vmm(11);
773 const Vmm vmm_dd_scale_x = Vmm(12);
774 const Vmm vmm_dd_scale = Vmm(13);
775 const Vmm vmm_src = Vmm(14);
776 const Vmm vmm_mean = Vmm(15);
777
778 const int bf16_emu_zmm_1_idx = 28;
779 const int bf16_emu_zmm_2_idx = 29;
780 const int bf16_emu_zmm_3_idx = 30;
781 const int bf16_emu_zmm_4_idx = 31;
782 const int tail_opmask_idx = 1;
783
784 Address src_ptr(size_t offt = 0) {
785 return vmmword[reg_src + offt * src_d_.data_type_size()];
786 }
787
788 Address d_dst_ptr(size_t offt = 0) {
789 return vmmword[reg_diff_dst + offt * d_dst_d_.data_type_size()];
790 }
791
792 Address d_src_ptr(size_t offt = 0) {
793 return vmmword[reg_diff_src + offt * d_src_d_.data_type_size()];
794 }
795
796 Address scale_ptr(size_t offt = 0) {
797 return vmmword[reg_scale + offt * sizeof(float)];
798 }
799
800 virtual void reduce(Vmm vmm_src, Vmm vmm_tmp) = 0;
801
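    // First pass over a row: accumulates the two per-row reductions used by
    // the diff-stats term,
    //     dd_scale   = sum_c(diff_dst[c] * scale[c])
    //     dd_scale_x = sum_c(diff_dst[c] * scale[c] * (src[c] - mean))
    // (scale is implicitly 1 when not used) as per-lane partial sums;
    // generate() reduces them horizontally and folds inv_sqrtvar into
    // dd_scale_x afterwards.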
802 void compute_dd_scales(size_t offt_elems, bool tail = false) {
803 Vmm vmm_ddst = vmm_dsrc;
804 io_[d_dst_d_.data_type()]->load(d_dst_ptr(offt_elems), vmm_ddst, tail);
805 if (use_scale_) {
806 io_[f32]->load(scale_ptr(offt_elems), vmm_scale, tail);
807 uni_vmulps(vmm_ddst, vmm_ddst, vmm_scale);
808 }
809 io_[src_d_.data_type()]->load(src_ptr(offt_elems), vmm_src, tail);
810
811 uni_vaddps(vmm_dd_scale, vmm_dd_scale, vmm_ddst);
812 uni_vsubps(vmm_src, vmm_src, vmm_mean);
813 uni_vfmadd231ps(vmm_dd_scale_x, vmm_ddst, vmm_src);
    }
815
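    // Second pass over a row: with x_hat = (src - mean) * inv_sqrtvar and
    // dd_scale_x already multiplied by inv_sqrtvar,
    //     diff_src = (diff_dst * scale
    //                 - (dd_scale + x_hat * dd_scale_x) / C) * inv_sqrtvar
    // The subtracted term is skipped (diff_src = diff_dst * scale
    // * inv_sqrtvar) when the stats are user-provided inputs and no gradient
    // has to be propagated through them.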
816 void compute_diff_src(size_t offt_elems, bool tail = false) {
817 Vmm vmm_ddst = vmm_dsrc;
818 io_[d_dst_d_.data_type()]->load(d_dst_ptr(offt_elems), vmm_ddst, tail);
819 if (use_scale_) {
820 io_[f32]->load(scale_ptr(offt_elems), vmm_scale, tail);
821 uni_vmulps(vmm_dsrc, vmm_dsrc, vmm_scale);
822 }
823 if (calculate_diff_stats_) {
824 io_[src_d_.data_type()]->load(src_ptr(offt_elems), vmm_src, tail);
825 uni_vsubps(vmm_src, vmm_src, vmm_mean);
826 uni_vmulps(vmm_src, vmm_src, vmm_inv_sqrtvar);
827 uni_vfmadd213ps(vmm_src, vmm_dd_scale_x, vmm_dd_scale);
828 uni_vdivps(vmm_src, vmm_src, vmm_C);
829 uni_vsubps(vmm_dsrc, vmm_dsrc, vmm_src);
830 }
831 uni_vmulps(vmm_dsrc, vmm_dsrc, vmm_inv_sqrtvar);
832 io_[d_src_d_.data_type()]->store(vmm_dsrc, d_src_ptr(offt_elems), tail);
    }
834
835 void generate() override {
836 const size_t c_src_size
837 = C_ * types::data_type_size(src_d_.data_type());
838 const size_t c_ddst_size
839 = C_ * types::data_type_size(d_dst_d_.data_type());
840 const size_t c_dsrc_size
841 = C_ * types::data_type_size(d_src_d_.data_type());
842 static const size_t float_size = types::data_type_size(f32);
843
844 preamble();
845
846 io_.init_bf16();
847 if (axis_simd_tail_) io_.prepare_tail_mask();
848
849#define PARAM_OFF(x) offsetof(ker_args_t, x)
850 mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
851 mov(reg_diff_dst, ptr[reg_param + PARAM_OFF(diff_dst)]);
852 mov(reg_diff_src, ptr[reg_param + PARAM_OFF(diff_src)]);
853 mov(reg_scale, ptr[reg_param + PARAM_OFF(ss)]);
854
855 if (calculate_diff_stats_)
856 mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
857 mov(reg_inv_sqrtvar, ptr[reg_param + PARAM_OFF(inv_sqrtvar)]);
858 mov(reg_block_end, ptr[reg_param + PARAM_OFF(block_size)]);
859#undef PARAM_OFF
860
861 mov(reg_tmp, float2int(C_));
862 uni_vmovq(xmm_tmp, reg_tmp);
863 uni_vbroadcastss(vmm_C, xmm_tmp);
864
        // reg_block_end holds block_size in bytes; adding the src base
        // pointer turns it into the end address of the block.
866 add(reg_block_end, reg_src);
867
868 Label unroll_loop, end;
869 L(unroll_loop);
870 {
871 cmp(reg_block_end, reg_src);
872 jle(end, T_NEAR);
873
874 uni_vmovss(xmm_tmp, dword[reg_inv_sqrtvar]);
875 uni_vbroadcastss(vmm_inv_sqrtvar, xmm_tmp);
876
877 if (calculate_diff_stats_) {
878 uni_vmovss(xmm_tmp, dword[reg_mean]);
879 uni_vbroadcastss(vmm_mean, xmm_tmp);
880
881 uni_vpxor(vmm_dd_scale, vmm_dd_scale, vmm_dd_scale);
882 uni_vpxor(vmm_dd_scale_x, vmm_dd_scale_x, vmm_dd_scale_x);
883
884 for (int i = 0; i < axis_simd_full_; i++)
885 compute_dd_scales(i * simd_w_);
886 if (axis_simd_tail_)
887 compute_dd_scales(axis_simd_full_ * simd_w_, true);
888
889 reduce(vmm_dd_scale, vmm_tmp);
890 reduce(vmm_dd_scale_x, vmm_tmp);
891 uni_vmulps(vmm_dd_scale_x, vmm_dd_scale_x, vmm_inv_sqrtvar);
892 }
893
894 for (int i = 0; i < axis_simd_full_; i++)
895 compute_diff_src(i * simd_w_);
896 if (axis_simd_tail_)
897 compute_diff_src(axis_simd_full_ * simd_w_, true);
898
899 add(reg_src, c_src_size);
900 add(reg_diff_dst, c_ddst_size);
901 add(reg_diff_src, c_dsrc_size);
902 if (calculate_diff_stats_) add(reg_mean, float_size);
903 add(reg_inv_sqrtvar, float_size);
904 jmp(unroll_loop);
905 }
906 L(end);
907
908 postamble();
909 }
910};
911
912template <cpu_isa_t isa>
913struct jit_diff_data_kernel_t;
914
915template <>
916struct jit_diff_data_kernel_t<avx512_core>
917 : public jit_diff_data_base_kernel_t<avx512_core> {
918
919 using jit_diff_data_base_kernel_t::jit_diff_data_base_kernel_t;
920
921 void reduce(Vmm vmm_src, Vmm vmm_tmp) override {
        vshuff32x4(vmm_tmp, vmm_src, vmm_src, 0x4E); // swap 256-bit halves
        vaddps(vmm_src, vmm_src, vmm_tmp);
        vshuff32x4(vmm_tmp, vmm_src, vmm_src, 0xB1); // swap 128-bit lanes in each half
        vaddps(vmm_src, vmm_src, vmm_tmp);
        vshufps(vmm_tmp, vmm_src, vmm_src, 0x4E); // swap 64-bit pairs in each lane
        vaddps(vmm_src, vmm_src, vmm_tmp);
        vshufps(vmm_tmp, vmm_src, vmm_src, 0xB1); // swap adjacent 32-bit elements
929 vaddps(vmm_src, vmm_src, vmm_tmp);
930 }
931};
932
933template <>
934struct jit_diff_data_kernel_t<avx2> : public jit_diff_data_base_kernel_t<avx2> {
935
936 using jit_diff_data_base_kernel_t::jit_diff_data_base_kernel_t;
937
938 void reduce(Vmm vmm_src, Vmm vmm_tmp) override {
        vperm2f128(vmm_tmp, vmm_src, vmm_src, 0x1); // swap 128-bit halves
        vaddps(vmm_src, vmm_src, vmm_tmp);
        vshufps(vmm_tmp, vmm_src, vmm_src, 0x4E); // swap 64-bit pairs in each half
        vaddps(vmm_src, vmm_src, vmm_tmp);
        vshufps(vmm_tmp, vmm_src, vmm_src, 0xB1); // swap adjacent 32-bit elements
944 vaddps(vmm_src, vmm_src, vmm_tmp);
945 }
946};
947
948diff_data_kernel_t *diff_data_kernel_t::create(
949 const layer_normalization_pd_t *pd) {
950 if (mayiuse(avx512_core)) {
951 return new jit_diff_data_kernel_t<avx512_core>(pd);
952 } else if (mayiuse(avx2)) {
953 return new jit_diff_data_kernel_t<avx2>(pd);
954 } else {
955 assert(!"kernel is empty.");
956 return nullptr;
957 }
958}
959
960status_t jit_uni_layer_normalization_fwd_t::execute_forward(
961 const exec_ctx_t &ctx) const {
962 auto scratchpad = ctx.get_scratchpad_grantor();
963 const auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
964 auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
965
966 auto scale = CTX_IN_MEM(const float *, DNNL_ARG_SCALE);
967 auto shift = CTX_IN_MEM(const float *, DNNL_ARG_SHIFT);
968
969 float *mean, *variance;
970 if (pd()->use_tmp_stats()) {
971 mean = scratchpad.template get<float>(key_lnorm_tmp_mean);
972 variance = scratchpad.template get<float>(key_lnorm_tmp_var);
973 } else {
974 mean = pd()->stats_are_src()
975 ? const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN))
976 : CTX_OUT_MEM(float *, DNNL_ARG_MEAN);
977 variance = pd()->stats_are_src()
978 ? const_cast<float *>(
979 CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE))
980 : CTX_OUT_MEM(float *, DNNL_ARG_VARIANCE);
981 }
982
983 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
984 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
985
986 const memory_desc_wrapper src_d(pd()->src_md());
987 const memory_desc_wrapper dst_d(pd()->dst_md());
988
989 const dim_t N = pd()->across_axis();
990 const dim_t C_padded = src_d.padded_dims()[pd()->ndims() - 1];
991
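    // Split the outer (non-normalized) dimension N across threads; each
    // thread runs the JIT kernel on a contiguous block of rows together with
    // the matching slice of the mean/variance arrays. The pointer arithmetic
    // relies on a dense layout with the (padded) normalization axis last.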
992 parallel(0, [&](const int ithr, const int nthr) {
993 dim_t N_start = 0, N_end = 0;
994 balance211(N, nthr, ithr, N_start, N_end);
995 const char *const __restrict src_ptr
996 = reinterpret_cast<const char *>(src)
997 + N_start * C_padded * src_d.data_type_size();
998 char *const __restrict dst_ptr = reinterpret_cast<char *>(dst)
999 + N_start * C_padded * dst_d.data_type_size();
1000 const int block_size = N_end - N_start;
1001 (*stat_and_data_kernel_)(src_ptr, dst_ptr, scale, shift, &mean[N_start],
1002 &variance[N_start], src_scales, dst_scales, block_size);
1003 });
1004 return status::success;
1005}
1006
1007status_t jit_uni_layer_normalization_bwd_t::execute_backward(
1008 const exec_ctx_t &ctx) const {
1009 status_t status = status::success;
1010
1011 auto scratchpad = ctx.get_scratchpad_grantor();
1012 auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
1013 auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
1014 auto scale = CTX_IN_MEM(float *, DNNL_ARG_SCALE);
1015 auto diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status);
1016
1017 auto diff_scale = CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DIFF_SCALE, status);
1018 CHECK(status);
1019 auto diff_shift = CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DIFF_SHIFT, status);
1020 CHECK(status);
1021
1022 const float *mean, *variance;
1023 if (pd()->use_tmp_stats()) {
1024 mean = scratchpad.template get<float>(key_lnorm_tmp_mean);
1025 variance = scratchpad.template get<float>(key_lnorm_tmp_var);
1026 } else {
1027 mean = CTX_IN_MEM(const float *, DNNL_ARG_MEAN);
1028 variance = CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE);
1029 }
1030
1031 float *const inv_sqrtvar
1032 = scratchpad.template get<float>(key_lnorm_inv_sqrtvar);
1033
1034 const memory_desc_wrapper src_d(pd()->src_md());
1035 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
1036 const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
1037
1038 const dim_t N = pd()->across_axis();
1039 const dim_t C = pd()->norm_axis();
1040 const dim_t C_padded = src_d.padded_dims()[pd()->ndims() - 1];
1041
1042 float *reduce = scratchpad.template get<float>(key_lnorm_reduction);
1043 if (diff_scale == nullptr)
1044 diff_scale = scratchpad.template get<float>(key_lnorm_tmp_diff_ss);
1045 if (diff_shift == nullptr) {
1046 diff_shift = scratchpad.template get<float>(key_lnorm_tmp_diff_ss);
1047 }
1048
1049 const int max_nthr = pd()->nthr_;
1050
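    // Phase 1: every thread accumulates private per-channel diff_scale and
    // diff_shift partial sums over its block of rows into the `reduce`
    // scratchpad and fills inv_sqrtvar for those rows.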
1051 parallel(max_nthr, [&](int ithr, int nthr) {
1052 dim_t N_start = 0, N_end = 0;
1053 balance211(N, nthr, ithr, N_start, N_end);
1054 const int block_size = N_end - N_start;
1055 const char *const __restrict src_ptr
1056 = reinterpret_cast<const char *>(src)
1057 + N_start * C_padded * src_d.data_type_size();
1058 const char *const __restrict diff_dst_ptr
1059 = reinterpret_cast<const char *>(diff_dst)
1060 + N_start * C_padded * diff_dst_d.data_type_size();
1061
1062 float *my_diff_gamma = reduce + C * ithr;
1063 float *my_diff_beta = reduce + C * nthr + C * ithr;
1064 for (dim_t c = 0; c < C; c++) {
1065 my_diff_gamma[c] = 0.;
1066 my_diff_beta[c] = 0.;
1067 }
1068 (*diff_ss_kernel_)(src_ptr, diff_dst_ptr, my_diff_gamma, my_diff_beta,
1069 &mean[N_start], &variance[N_start], &inv_sqrtvar[N_start],
1070 block_size);
1071 });
1072
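    // Phase 2: reduce the per-thread partial sums. The `reduce` scratchpad
    // holds 2 * nthr * C floats: diff_gamma partials first, then diff_beta
    // partials, each laid out as nthr consecutive chunks of C values.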
1073 parallel_nd(C, [&](dim_t c) {
1074 float diff_gamma = 0, diff_beta = 0;
1075 for (dim_t n = 0; n < max_nthr; n++) {
1076 diff_gamma += reduce[C * n + c];
1077 diff_beta += reduce[C * max_nthr + C * n + c];
1078 }
1079 diff_scale[c] = diff_gamma;
1080 diff_shift[c] = diff_beta;
1081 });
1082
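    // Phase 3: compute diff_src row by row using the same row partitioning
    // as phase 1.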
1083 parallel(max_nthr, [&](int ithr, int nthr) {
1084 dim_t N_start = 0, N_end = 0;
1085 balance211(N, nthr, ithr, N_start, N_end);
1086 const int block_size = N_end - N_start;
1087 const char *const __restrict src_ptr
1088 = reinterpret_cast<const char *>(src)
1089 + N_start * C_padded * src_d.data_type_size();
1090 const char *const __restrict diff_dst_ptr
1091 = reinterpret_cast<const char *>(diff_dst)
1092 + N_start * C_padded * diff_dst_d.data_type_size();
1093 char *const __restrict diff_src_ptr = reinterpret_cast<char *>(diff_src)
1094 + N_start * C_padded * diff_src_d.data_type_size();
1095
1096 (*diff_data_kernel_)(src_ptr, diff_dst_ptr, diff_src_ptr, scale,
1097 &mean[N_start], &inv_sqrtvar[N_start], block_size);
1098 });
1099 return status::success;
1100}
1101
1102} // namespace x64
1103} // namespace cpu
1104} // namespace impl
1105} // namespace dnnl
1106