1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <math.h> |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/reorder.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | |
25 | #include "cpu/cpu_primitive.hpp" |
26 | |
27 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
28 | #include "cpu/x64/jit_generator.hpp" |
29 | #include "cpu/x64/jit_uni_layer_normalization.hpp" |
30 | #include "cpu/x64/utils/jit_io_helper.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | namespace x64 { |
36 | |
37 | using namespace memory_tracking::names; |
38 | using namespace data_type; |
39 | using namespace Xbyak; |
40 | |
41 | cpu_isa_t get_io_isa(cpu_isa_t isa, bool has_f16) { |
    // Re-use the avx512_core instantiation for f16 by promoting the io
    // ISA to avx512_core_fp16.
43 | return has_f16 && is_superset(isa, avx512_core) ? avx512_core_fp16 : isa; |
44 | } |
45 | |
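// Forward kernel: for each row of the assigned block it computes (or loads)
// mean and variance over the normalization axis and writes
// dst = scale * (src - mean) / sqrt(var + eps) + shift, multiplied by the
// combined src/dst quantization scales.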
46 | template <cpu_isa_t isa> |
47 | struct jit_stat_and_data_base_kernel_t : stat_and_data_kernel_t, |
48 | public jit_generator { |
49 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_lnorm_stat_and_data_kernel_t); |
50 | |
51 | void operator()(const void *src, void *dst, const float *scale, |
52 | const float *shift, float *mean, float *var, |
53 | const float *src_scales, const float *dst_scales, |
54 | const size_t block_size) const override { |
55 | ker_args_t args; |
56 | args.src = src; |
57 | args.dst = dst; |
58 | args.scale = scale; |
59 | args.shift = shift; |
60 | args.mean = mean; |
61 | args.var = var; |
62 | args.src_scales = src_scales; |
63 | args.dst_scales = dst_scales; |
64 | args.block_size |
65 | = block_size * C_ * types::data_type_size(src_d_.data_type()); |
66 | args.eps = eps_; |
67 | jit_generator::operator()(&args); |
68 | } |
69 | |
70 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
71 | |
72 | jit_stat_and_data_base_kernel_t(const layer_normalization_pd_t *pd) |
73 | : stat_and_data_kernel_t(pd) |
74 | , jit_generator(jit_name()) |
75 | , src_d_(pd_->src_md()) |
76 | , dst_d_(pd_->dst_md()) |
77 | , simd_w_(vlen / sizeof(float)) |
78 | , C_(pd_->norm_axis()) |
79 | , axis_simd_full_(C_ / simd_w_) |
80 | , axis_simd_tail_(C_ % simd_w_) |
81 | , use_scale_(pd_->use_scale()) |
82 | , use_shift_(pd_->use_shift()) |
83 | , save_stats_(pd_->is_training()) |
84 | , calculate_stats_(!pd_->stats_are_src()) |
85 | , eps_(pd_->desc()->layer_norm_epsilon) { |
86 | |
87 | io::io_conf_t io_conf; |
88 | io::io_tail_conf_t io_tail_conf(simd_w_, axis_simd_tail_, |
89 | tail_opmask_idx, vmm_tail_mask.getIdx(), reg_tmp); |
90 | io::io_emu_bf16_conf_t io_bf16_conf(bf16_emu_zmm_1_idx, |
91 | bf16_emu_zmm_2_idx, bf16_emu_zmm_3_idx, reg_tmp, |
92 | bf16_emu_zmm_4_idx); |
93 | io::io_saturation_conf_t io_saturation_conf( |
94 | vmm_zero.getIdx(), vmm_saturation_ubound.getIdx(), reg_tmp); |
95 | const auto io_isa = get_io_isa(isa, |
96 | utils::one_of(f16, src_d_.data_type(), dst_d_.data_type())); |
97 | io_ = io::jit_io_multi_dt_helper_t<Vmm>(this, io_isa, |
98 | {src_d_.data_type(), dst_d_.data_type(), f32 /* stats */}, |
99 | io_conf, io_tail_conf, io_bf16_conf, |
100 | {{dst_d_.data_type(), io_saturation_conf}}); |
101 | } |
102 | |
103 | protected: |
104 | static constexpr int unroll_factor_ = 4; |
105 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
106 | const AddressFrame &vmmword |
107 | = (isa == sse41) ? xword : (isa == avx2) ? yword : zword; |
108 | const int vlen = cpu_isa_traits<isa>::vlen; |
109 | |
110 | struct ker_args_t { |
111 | const void *src; |
112 | void *dst; |
113 | const float *scale; |
114 | const float *shift; |
115 | const float *mean; |
116 | const float *var; |
117 | const float *src_scales; |
118 | const float *dst_scales; |
119 | size_t block_size; |
120 | float eps; |
121 | }; |
122 | |
123 | io::jit_io_multi_dt_helper_t<Vmm> io_; |
124 | const memory_desc_wrapper src_d_, dst_d_; |
125 | const size_t simd_w_; |
126 | const dim_t C_; |
127 | const dim_t axis_simd_full_; |
128 | const dim_t axis_simd_tail_; |
129 | const bool use_scale_; |
130 | const bool use_shift_; |
131 | const bool save_stats_; |
132 | const bool calculate_stats_; |
133 | const float eps_; |
134 | |
135 | const Reg64 reg_param = abi_param1; |
136 | const Reg64 reg_src = rdx; |
137 | const Reg64 reg_dst = rax; |
138 | const Reg64 reg_mean = rbx; |
139 | const Reg64 reg_scale = r8; |
140 | const Reg64 reg_block_end = r9; |
141 | const Reg64 reg_eps = r10; |
142 | const Reg64 reg_tmp = r11; |
143 | const Reg64 reg_shift = r12; |
144 | const Reg64 reg_var = r13; |
145 | const Reg64 reg_src_scales = r14; |
146 | const Reg64 reg_dst_scales = r15; |
147 | |
148 | const Vmm vmm_tail_mask = Vmm(0); |
    // Vmm(4)-Vmm(8) fall inside the scratch range of the unrolled stats
    // computation, which is safe: they are only used afterwards, when
    // computing dst.
    const Vmm vmm_zero = Vmm(4);
    const Vmm vmm_saturation_ubound = Vmm(5);
    const Vmm vmm_combined_scales = Vmm(6);
    const Vmm vmm_scale = Vmm(7);
    const Vmm vmm_shift = Vmm(8);
156 | const Vmm vmm_ones = Vmm(9); |
157 | const Vmm vmm_eps = Vmm(10); |
158 | const Vmm vmm_c = Vmm(11); |
159 | const Vmm vmm_mean = Vmm(12); |
160 | const Vmm vmm_inv_sqrtvar = Vmm(13); |
161 | const Vmm vmm_dst = Vmm(14); |
162 | const Vmm vmm_tmp = Vmm(15); |
163 | const Xmm xmm_tmp = Xmm(15); |
164 | |
165 | const int bf16_emu_zmm_1_idx = 28; |
166 | const int bf16_emu_zmm_2_idx = 29; |
167 | const int bf16_emu_zmm_3_idx = 30; |
168 | const int bf16_emu_zmm_4_idx = 31; |
169 | const int tail_opmask_idx = 1; |
170 | |
171 | Address src_ptr(size_t offt = 0) { |
172 | return vmmword[reg_src + offt * src_d_.data_type_size()]; |
173 | } |
174 | |
175 | Address dst_ptr(size_t offt = 0) { |
176 | return vmmword[reg_dst + offt * dst_d_.data_type_size()]; |
177 | } |
178 | |
179 | Address mean_ptr(size_t offt = 0) { |
180 | return vmmword[reg_mean + offt * sizeof(float)]; |
181 | } |
182 | |
183 | Address var_ptr(size_t offt = 0) { |
184 | return vmmword[reg_var + offt * sizeof(float)]; |
185 | } |
186 | |
187 | Address scale_ptr(size_t offt = 0) { |
188 | return vmmword[reg_scale + offt * sizeof(float)]; |
189 | } |
190 | |
191 | Address shift_ptr(size_t offt = 0) { |
192 | return vmmword[reg_shift + offt * sizeof(float)]; |
193 | } |
194 | |
195 | virtual void compute_var() = 0; |
196 | virtual void reduce(Vmm vmm_src, Vmm vmm_tmp) = 0; |
197 | |
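    // Streams one row through `op` to accumulate a statistic over the
    // normalization axis: zero-initializes up to `unroll_factor_` vector
    // accumulators, tree-reduces them, processes the unroll and SIMD tails,
    // then horizontally reduces and divides by C, leaving the result
    // broadcast across all lanes of `vmm_stat`.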
198 | template <typename F> |
199 | void compute(Vmm vmm_stat, F op) { |
200 | bool need_tail = false; |
201 | int base_idx = 1; // Preserve `0` for tail on AVX2. |
202 | |
203 | uni_vpxor(Vmm(base_idx), Vmm(base_idx), Vmm(base_idx)); |
204 | if (axis_simd_full_ > 0) { |
205 | const int unroll |
206 | = axis_simd_full_ >= unroll_factor_ ? unroll_factor_ : 1; |
207 | assert(math::is_pow2(unroll)); |
208 | |
209 | for (int i = base_idx + 1; i < base_idx + unroll; i++) |
210 | uni_vpxor(Vmm(i), Vmm(i), Vmm(i)); |
211 | |
212 | // unrolled loop |
213 | for (int i = 0; i < axis_simd_full_ / unroll; i++) |
214 | for (int j = base_idx; j < base_idx + unroll; j++) { |
215 | io_[src_d_.data_type()]->load( |
216 | src_ptr((i * unroll + j - base_idx) * simd_w_), |
217 | Vmm(j + unroll), need_tail); |
218 | op(Vmm(j), Vmm(j + unroll), need_tail); |
219 | } |
220 | |
221 | // unrolled loop reduction |
222 | int n = unroll; |
223 | while (n > 1) { |
224 | for (int j = base_idx; j < base_idx + n / 2; j++) |
225 | uni_vaddps(Vmm(j), Vmm(j), Vmm(j + n / 2)); |
226 | n = n / 2; |
227 | } |
228 | |
229 | // unrolled loop remainder |
230 | for (int i = utils::rnd_dn(axis_simd_full_, unroll); |
231 | i < axis_simd_full_; i++) { |
232 | io_[src_d_.data_type()]->load( |
233 | src_ptr(i * simd_w_), Vmm(base_idx + 1), need_tail); |
234 | op(Vmm(base_idx), Vmm(base_idx + 1), need_tail); |
235 | } |
236 | } |
237 | |
238 | if (axis_simd_tail_ > 0) { |
239 | need_tail = true; |
240 | // vector remainder |
241 | io_[src_d_.data_type()]->load(src_ptr(axis_simd_full_ * simd_w_), |
242 | Vmm(base_idx + 1), need_tail); |
243 | op(Vmm(base_idx), Vmm(base_idx + 1), need_tail); |
244 | } |
245 | |
246 | reduce(Vmm(base_idx), Vmm(base_idx + 1)); |
247 | uni_vdivps(Vmm(base_idx), Vmm(base_idx), vmm_c, vmm_tmp); |
248 | uni_vmovups(vmm_stat, Vmm(base_idx)); |
249 | } |
250 | |
251 | void compute_mean() { |
252 | compute(vmm_mean, [&](Vmm vmm_dst, Vmm vmm_src, bool need_tail) { |
253 | uni_vaddps(vmm_dst, vmm_dst, vmm_src); |
254 | }); |
255 | if (save_stats_) uni_vmovss(ptr[reg_mean], Xmm(vmm_mean.getIdx())); |
256 | } |
257 | |
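    // Normalizes one SIMD vector of the row: dst = (src - mean) *
    // inv_sqrtvar, applies scale/shift when present, multiplies by the
    // combined quantization scale, and stores with conversion (and
    // saturation) to the dst data type.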
258 | void calculate_dst(size_t offt_elems, bool tail = false) { |
259 | if (use_scale_) { |
260 | io_[f32]->load(scale_ptr(offt_elems), vmm_scale, tail); |
261 | } |
262 | if (use_shift_) { |
263 | io_[f32]->load(shift_ptr(offt_elems), vmm_shift, tail); |
264 | } |
265 | io_[src_d_.data_type()]->load(src_ptr(offt_elems), vmm_dst, tail); |
266 | uni_vsubps(vmm_dst, vmm_dst, vmm_mean); |
267 | uni_vmulps(vmm_dst, vmm_dst, vmm_inv_sqrtvar); |
268 | if (use_scale_ && use_shift_) |
269 | uni_vfmadd213ps(vmm_dst, vmm_scale, vmm_shift); |
270 | else { |
271 | if (use_scale_) uni_vmulps(vmm_dst, vmm_dst, vmm_scale); |
272 | if (use_shift_) uni_vaddps(vmm_dst, vmm_dst, vmm_shift); |
273 | } |
274 | uni_vmulps(vmm_dst, vmm_dst, vmm_combined_scales); |
275 | io_[dst_d_.data_type()]->store(vmm_dst, dst_ptr(offt_elems), tail); |
276 | } |
277 | |
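    // Emits the block loop: each iteration normalizes one row of C elements,
    // then advances src/dst by one row and mean/var by one float until
    // reg_src reaches reg_block_end.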
278 | void generate() override { |
279 | const size_t c_src_size |
280 | = C_ * types::data_type_size(src_d_.data_type()); |
281 | const size_t c_dst_size |
282 | = C_ * types::data_type_size(dst_d_.data_type()); |
283 | static const size_t float_size = types::data_type_size(f32); |
284 | |
285 | preamble(); |
286 | |
287 | io_.init_bf16(); |
288 | if (axis_simd_tail_) io_.prepare_tail_mask(); |
289 | |
290 | #define PARAM_OFF(x) offsetof(ker_args_t, x) |
291 | mov(reg_src, ptr[reg_param + PARAM_OFF(src)]); |
292 | mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); |
293 | mov(reg_scale, ptr[reg_param + PARAM_OFF(scale)]); |
294 | mov(reg_shift, ptr[reg_param + PARAM_OFF(shift)]); |
295 | mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]); |
296 | mov(reg_var, ptr[reg_param + PARAM_OFF(var)]); |
297 | mov(reg_src_scales, ptr[reg_param + PARAM_OFF(src_scales)]); |
298 | mov(reg_dst_scales, ptr[reg_param + PARAM_OFF(dst_scales)]); |
299 | mov(reg_block_end, ptr[reg_param + PARAM_OFF(block_size)]); |
300 | mov(reg_eps, ptr[reg_param + PARAM_OFF(eps)]); |
301 | #undef PARAM_OFF |
302 | |
303 | uni_vmovq(xmm_tmp, reg_eps); |
304 | uni_vbroadcastss(vmm_eps, xmm_tmp); |
305 | mov(reg_tmp, float2int(1.f)); |
306 | uni_vmovq(xmm_tmp, reg_tmp); |
307 | uni_vbroadcastss(vmm_ones, xmm_tmp); |
308 | mov(reg_tmp, float2int(C_)); |
309 | uni_vmovq(xmm_tmp, reg_tmp); |
310 | uni_vbroadcastss(vmm_c, xmm_tmp); |
311 | |
312 | // add block_start to block_size to define block_end |
313 | add(reg_block_end, reg_src); |
314 | |
315 | Label unroll_loop, end; |
316 | L(unroll_loop); |
317 | { |
318 | cmp(reg_block_end, reg_src); |
319 | jle(end, T_NEAR); |
320 | |
321 | if (calculate_stats_) { |
322 | // compute stats |
323 | compute_mean(); |
324 | compute_var(); |
325 | } else { |
326 | // read mean and var from input |
327 | uni_vmovss(xmm_tmp, dword[reg_mean]); |
328 | uni_vbroadcastss(vmm_mean, xmm_tmp); |
329 | uni_vmovss(xmm_tmp, dword[reg_var]); |
330 | uni_vbroadcastss(vmm_inv_sqrtvar, xmm_tmp); |
331 | } |
332 | |
333 | // calculate inv_sqrtvar |
334 | uni_vaddps(vmm_inv_sqrtvar, vmm_inv_sqrtvar, vmm_eps); |
335 | uni_vsqrtps(vmm_inv_sqrtvar, vmm_inv_sqrtvar); |
336 | uni_vdivps(vmm_inv_sqrtvar, vmm_ones, vmm_inv_sqrtvar, vmm_tmp); |
337 | |
            // Precompute and broadcast the combined src * dst output scale
            // (the values may only be known at runtime).
339 | uni_vmovss(xmm_tmp, dword[reg_src_scales]); |
340 | uni_vbroadcastss(vmm_combined_scales, xmm_tmp); |
341 | uni_vmovss(xmm_tmp, dword[reg_dst_scales]); |
342 | uni_vbroadcastss(vmm_tmp, xmm_tmp); |
343 | uni_vmulps(vmm_combined_scales, vmm_combined_scales, vmm_tmp); |
344 | io_.init_saturate_f32({dst_d_.data_type()}); |
345 | |
346 | // calculate dst |
347 | for (int i = 0; i < axis_simd_full_; i++) |
348 | calculate_dst(i * simd_w_); |
349 | if (axis_simd_tail_) calculate_dst(axis_simd_full_ * simd_w_, true); |
350 | |
351 | add(reg_src, c_src_size); |
352 | add(reg_dst, c_dst_size); |
353 | add(reg_mean, float_size); |
354 | add(reg_var, float_size); |
355 | jmp(unroll_loop); |
356 | } |
357 | L(end); |
358 | |
359 | postamble(); |
360 | } |
361 | }; |
362 | |
363 | template <cpu_isa_t isa> |
364 | struct jit_stat_and_data_kernel_t; |
365 | |
366 | template <> |
367 | struct jit_stat_and_data_kernel_t<avx512_core> |
368 | : public jit_stat_and_data_base_kernel_t<avx512_core> { |
369 | |
370 | using jit_stat_and_data_base_kernel_t::jit_stat_and_data_base_kernel_t; |
371 | |
372 | void compute_var() override { |
373 | compute(vmm_inv_sqrtvar, [&](Vmm vmm_dst, Vmm vmm_src, bool need_tail) { |
            // Tail lanes must stay zero after the subtraction so they do
            // not pollute the accumulated sum.
375 | if (!need_tail) |
376 | uni_vsubps(vmm_src, vmm_src, vmm_mean); |
377 | else |
378 | uni_vsubps(vmm_src | Opmask(tail_opmask_idx) | T_z, vmm_src, |
379 | vmm_mean); |
380 | uni_vfmadd231ps(vmm_dst, vmm_src, vmm_src); |
381 | }); |
382 | if (save_stats_) |
383 | uni_vmovss(ptr[reg_var], Xmm(vmm_inv_sqrtvar.getIdx())); |
384 | } |
385 | |
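    // Horizontal sum over all f32 lanes; the total ends up broadcast to
    // every lane of vmm_src.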
386 | void reduce(Vmm vmm_src, Vmm vmm_tmp) override { |
387 | vshuff32x4(vmm_tmp, vmm_src, vmm_src, 0x4E); // 256-bit shuffle |
388 | vaddps(vmm_src, vmm_src, vmm_tmp); |
389 | vshuff32x4(vmm_tmp, vmm_src, vmm_src, 0xB1); // 128/256-bit shuffle |
390 | vaddps(vmm_src, vmm_src, vmm_tmp); |
391 | vshufps(vmm_tmp, vmm_src, vmm_src, 0x4E); // 64/128-bit shuffle |
392 | vaddps(vmm_src, vmm_src, vmm_tmp); |
393 | vshufps(vmm_tmp, vmm_src, vmm_src, 0xB1); // 32/64-bit shuffle |
394 | vaddps(vmm_src, vmm_src, vmm_tmp); |
395 | } |
396 | }; |
397 | |
398 | template <> |
399 | struct jit_stat_and_data_kernel_t<avx2> |
400 | : jit_stat_and_data_base_kernel_t<avx2> { |
401 | |
402 | using jit_stat_and_data_base_kernel_t::jit_stat_and_data_base_kernel_t; |
403 | |
404 | void compute_var() override { |
405 | compute(vmm_inv_sqrtvar, [&](Vmm vmm_dst, Vmm vmm_src, bool need_tail) { |
            // Tail lanes must stay zero after the subtraction so they do
            // not pollute the accumulated sum.
407 | if (!need_tail) |
408 | uni_vsubps(vmm_src, vmm_src, vmm_mean); |
409 | else { |
                // The tail path runs at most once, so clobbering `vmm_tmp`
                // here is fine.
411 | uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp); |
412 | uni_vblendvps(vmm_tmp, vmm_tmp, vmm_mean, vmm_tail_mask); |
413 | uni_vsubps(vmm_src, vmm_src, vmm_tmp); |
414 | } |
415 | uni_vfmadd231ps(vmm_dst, vmm_src, vmm_src); |
416 | }); |
417 | if (save_stats_) |
418 | uni_vmovss(ptr[reg_var], Xmm(vmm_inv_sqrtvar.getIdx())); |
419 | } |
420 | |
421 | void reduce(Vmm vmm_src, Vmm vmm_tmp) override { |
422 | vperm2f128(vmm_tmp, vmm_src, vmm_src, 0x1); // 128/256-bit shuffle |
423 | vaddps(vmm_src, vmm_src, vmm_tmp); |
424 | vshufps(vmm_tmp, vmm_src, vmm_src, 0x4E); // 64/128-bit shuffle |
425 | vaddps(vmm_src, vmm_src, vmm_tmp); |
426 | vshufps(vmm_tmp, vmm_src, vmm_src, 0xB1); // 32/64-bit shuffle |
427 | vaddps(vmm_src, vmm_src, vmm_tmp); |
428 | } |
429 | }; |
430 | |
431 | template <> |
432 | struct jit_stat_and_data_kernel_t<sse41> |
433 | : jit_stat_and_data_base_kernel_t<sse41> { |
434 | |
435 | using jit_stat_and_data_base_kernel_t::jit_stat_and_data_base_kernel_t; |
436 | |
437 | void compute_var() override { |
438 | compute(vmm_inv_sqrtvar, [&](Vmm vmm_dst, Vmm vmm_src, bool need_tail) { |
            // Tail lanes must stay zero after the subtraction so they do
            // not pollute the accumulated sum.
440 | if (!need_tail) |
441 | uni_vsubps(vmm_src, vmm_src, vmm_mean); |
442 | else { |
                // The tail path runs at most once, so clobbering `vmm_tmp`
                // here is fine.
444 | uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp); |
445 | uni_vblendvps(vmm_tmp, vmm_tmp, vmm_mean, vmm_tail_mask); |
446 | uni_vsubps(vmm_src, vmm_src, vmm_tmp); |
447 | } |
448 | uni_vfmadd231ps(vmm_dst, vmm_src, vmm_src); |
449 | }); |
450 | if (save_stats_) |
451 | uni_vmovss(ptr[reg_var], Xmm(vmm_inv_sqrtvar.getIdx())); |
452 | } |
453 | |
454 | void reduce(Vmm vmm_src, Vmm vmm_tmp) override { |
455 | uni_vmovups(vmm_tmp, vmm_src); |
456 | shufps(vmm_tmp, vmm_tmp, 0x4E); // 64/128-bit shuffle |
457 | uni_vaddps(vmm_src, vmm_src, vmm_tmp); |
458 | uni_vmovups(vmm_tmp, vmm_src); |
459 | shufps(vmm_tmp, vmm_tmp, 0xB1); // 32/64-bit shuffle |
460 | uni_vaddps(vmm_src, vmm_src, vmm_tmp); |
461 | } |
462 | }; |
463 | |
464 | stat_and_data_kernel_t *stat_and_data_kernel_t::create( |
465 | const layer_normalization_pd_t *pd) { |
466 | if (mayiuse(avx512_core)) { |
467 | return new jit_stat_and_data_kernel_t<avx512_core>(pd); |
468 | } else if (mayiuse(avx2)) { |
469 | return new jit_stat_and_data_kernel_t<avx2>(pd); |
470 | } else if (mayiuse(sse41)) { |
471 | return new jit_stat_and_data_kernel_t<sse41>(pd); |
472 | } else { |
        assert(!"kernel is empty.");
474 | return nullptr; |
475 | } |
476 | } |
477 | |
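// Backward kernel for the scale and shift gradients. Per row it accumulates
// diff_shift += diff_dst and
// diff_scale += diff_dst * (src - mean) * inv_sqrtvar
// into the caller-provided per-thread buffers.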
478 | template <cpu_isa_t isa> |
479 | struct jit_diff_ss_kernel_t : diff_ss_kernel_t, public jit_generator { |
480 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_lnorm_diff_ss_kernel_t); |
481 | |
482 | void operator()(const void *src, const void *diff_dst, float *diff_scale, |
483 | float *diff_shift, const float *mean, const float *var, |
484 | float *const inv_sqrtvar, const size_t block_size) const override { |
485 | ker_args_t args; |
486 | args.src = src; |
487 | args.diff_dst = diff_dst; |
488 | args.diff_scale = diff_scale; |
489 | args.diff_shift = diff_shift; |
490 | args.mean = mean; |
491 | for (size_t i = 0; i < block_size; i++) { |
492 | #ifdef __INTEL_COMPILER |
            // Without `volatile`, ICC at -O2/-O3 optimizes the denominator
            // out of inv_sqrtvar and computes 1/denom with lower precision.
495 | const volatile float denom = sqrtf(var[i] + eps_); |
496 | #else |
497 | const float denom = sqrtf(var[i] + eps_); |
498 | #endif |
499 | inv_sqrtvar[i] = 1.f / denom; |
500 | } |
501 | args.inv_sqrtvar = inv_sqrtvar; |
502 | args.block_size |
503 | = block_size * C_ * types::data_type_size(src_d_.data_type()); |
504 | jit_generator::operator()(&args); |
505 | } |
506 | |
507 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
508 | |
509 | jit_diff_ss_kernel_t(const layer_normalization_pd_t *pd) |
510 | : diff_ss_kernel_t(pd) |
511 | , jit_generator(jit_name()) |
512 | , src_d_(pd_->src_md()) |
513 | , d_dst_d_(pd_->diff_dst_md()) |
514 | , simd_w_(vlen / sizeof(float)) |
515 | , C_(pd_->norm_axis()) |
516 | , axis_simd_full_(C_ / simd_w_) |
517 | , axis_simd_tail_(C_ % simd_w_) |
518 | , eps_(pd_->desc()->layer_norm_epsilon) { |
519 | |
520 | io::io_conf_t io_conf; |
521 | io::io_tail_conf_t io_tail_conf(simd_w_, axis_simd_tail_, |
522 | tail_opmask_idx, vmm_tail_mask.getIdx(), reg_tmp); |
523 | io::io_emu_bf16_conf_t io_bf16_conf(bf16_emu_zmm_1_idx, |
524 | bf16_emu_zmm_2_idx, bf16_emu_zmm_3_idx, reg_tmp, |
525 | bf16_emu_zmm_4_idx); |
526 | const auto io_isa = get_io_isa(isa, |
527 | utils::one_of(f16, src_d_.data_type(), d_dst_d_.data_type())); |
528 | io_ = io::jit_io_multi_dt_helper_t<Vmm>(this, io_isa, |
529 | {src_d_.data_type(), d_dst_d_.data_type(), f32 /* stats */}, |
530 | io_conf, io_tail_conf, io_bf16_conf); |
531 | } |
532 | |
533 | protected: |
534 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
535 | const AddressFrame &vmmword |
536 | = (isa == sse41) ? xword : (isa == avx2) ? yword : zword; |
537 | const int vlen = cpu_isa_traits<isa>::vlen; |
538 | |
539 | struct ker_args_t { |
540 | const void *src; |
541 | const void *diff_dst; |
542 | float *diff_scale; |
543 | float *diff_shift; |
544 | const float *mean; |
545 | const float *inv_sqrtvar; |
546 | size_t block_size; |
547 | }; |
548 | |
549 | io::jit_io_multi_dt_helper_t<Vmm> io_; |
550 | const memory_desc_wrapper src_d_, d_dst_d_; |
551 | const size_t simd_w_; |
552 | const dim_t C_; |
553 | const dim_t axis_simd_full_; |
554 | const dim_t axis_simd_tail_; |
555 | const float eps_; |
556 | |
557 | const Reg64 reg_param = abi_param1; |
558 | const Reg64 reg_src = rdx; |
559 | const Reg64 reg_diff_dst = rax; |
560 | const Reg64 reg_mean = rbx; |
561 | const Reg64 reg_diff_scale = r8; |
562 | const Reg64 reg_block_end = r9; |
563 | const Reg64 reg_tmp = r11; |
564 | const Reg64 reg_diff_shift = r12; |
565 | const Reg64 reg_inv_sqrtvar = r13; |
566 | |
567 | const Vmm vmm_tail_mask = Vmm(0); |
568 | const Xmm xmm_tmp = Xmm(9); |
569 | const Vmm vmm_inv_sqrtvar = Vmm(10); |
570 | const Vmm vmm_ddst = Vmm(11); |
571 | const Vmm vmm_dscale = Vmm(12); |
572 | const Vmm vmm_dshift = Vmm(13); |
573 | const Vmm vmm_src = Vmm(14); |
574 | const Vmm vmm_mean = Vmm(15); |
575 | |
576 | const int bf16_emu_zmm_1_idx = 28; |
577 | const int bf16_emu_zmm_2_idx = 29; |
578 | const int bf16_emu_zmm_3_idx = 30; |
579 | const int bf16_emu_zmm_4_idx = 31; |
580 | const int tail_opmask_idx = 1; |
581 | |
582 | Address src_ptr(size_t offt = 0) { |
583 | return vmmword[reg_src + offt * src_d_.data_type_size()]; |
584 | } |
585 | |
586 | Address d_dst_ptr(size_t offt = 0) { |
587 | return vmmword[reg_diff_dst + offt * d_dst_d_.data_type_size()]; |
588 | } |
589 | |
590 | Address d_scale_ptr(size_t offt = 0) { |
591 | return vmmword[reg_diff_scale + offt * sizeof(float)]; |
592 | } |
593 | |
594 | Address d_shift_ptr(size_t offt = 0) { |
595 | return vmmword[reg_diff_shift + offt * sizeof(float)]; |
596 | } |
597 | |
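    // Accumulates one SIMD vector of diff_scale/diff_shift; the partial sums
    // are kept in memory and updated in place.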
598 | void calculate_diff_scale_shift(size_t offt_elems, bool tail = false) { |
599 | io_[d_dst_d_.data_type()]->load(d_dst_ptr(offt_elems), vmm_ddst, tail); |
600 | io_[f32]->load(d_scale_ptr(offt_elems), vmm_dscale, tail); |
601 | io_[f32]->load(d_shift_ptr(offt_elems), vmm_dshift, tail); |
602 | io_[src_d_.data_type()]->load(src_ptr(offt_elems), vmm_src, tail); |
603 | |
604 | uni_vaddps(vmm_dshift, vmm_dshift, vmm_ddst); |
605 | uni_vsubps(vmm_src, vmm_src, vmm_mean); |
606 | uni_vmulps(vmm_src, vmm_src, vmm_inv_sqrtvar); |
607 | uni_vfmadd231ps(vmm_dscale, vmm_src, vmm_ddst); |
608 | |
609 | io_[f32]->store(vmm_dscale, d_scale_ptr(offt_elems), tail); |
610 | io_[f32]->store(vmm_dshift, d_shift_ptr(offt_elems), tail); |
611 | }; |
612 | |
613 | void generate() override { |
614 | const size_t c_src_size |
615 | = C_ * types::data_type_size(src_d_.data_type()); |
616 | const size_t c_ddst_size |
617 | = C_ * types::data_type_size(d_dst_d_.data_type()); |
618 | static const size_t float_size = types::data_type_size(f32); |
619 | |
620 | preamble(); |
621 | |
622 | io_.init_bf16(); |
623 | if (axis_simd_tail_) io_.prepare_tail_mask(); |
624 | |
625 | #define PARAM_OFF(x) offsetof(ker_args_t, x) |
626 | mov(reg_src, ptr[reg_param + PARAM_OFF(src)]); |
627 | mov(reg_diff_dst, ptr[reg_param + PARAM_OFF(diff_dst)]); |
628 | mov(reg_diff_scale, ptr[reg_param + PARAM_OFF(diff_scale)]); |
629 | mov(reg_diff_shift, ptr[reg_param + PARAM_OFF(diff_shift)]); |
630 | mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]); |
631 | mov(reg_inv_sqrtvar, ptr[reg_param + PARAM_OFF(inv_sqrtvar)]); |
632 | mov(reg_block_end, ptr[reg_param + PARAM_OFF(block_size)]); |
633 | #undef PARAM_OFF |
634 | |
635 | // add block_start to block_size to define block_end |
636 | add(reg_block_end, reg_src); |
637 | |
638 | Label unroll_loop, end; |
639 | L(unroll_loop); |
640 | { |
641 | cmp(reg_block_end, reg_src); |
642 | jle(end, T_NEAR); |
643 | |
644 | uni_vmovss(xmm_tmp, dword[reg_mean]); |
645 | uni_vbroadcastss(vmm_mean, xmm_tmp); |
646 | uni_vmovss(xmm_tmp, dword[reg_inv_sqrtvar]); |
647 | uni_vbroadcastss(vmm_inv_sqrtvar, xmm_tmp); |
648 | |
649 | for (int i = 0; i < axis_simd_full_; i++) |
650 | calculate_diff_scale_shift(i * simd_w_); |
651 | if (axis_simd_tail_) |
652 | calculate_diff_scale_shift(axis_simd_full_ * simd_w_, true); |
653 | |
654 | add(reg_src, c_src_size); |
655 | add(reg_diff_dst, c_ddst_size); |
656 | add(reg_mean, float_size); |
657 | add(reg_inv_sqrtvar, float_size); |
658 | jmp(unroll_loop); |
659 | } |
660 | L(end); |
661 | |
662 | postamble(); |
663 | } |
664 | }; |
665 | |
666 | diff_ss_kernel_t *diff_ss_kernel_t::create(const layer_normalization_pd_t *pd) { |
667 | if (mayiuse(avx512_core)) { |
668 | return new jit_diff_ss_kernel_t<avx512_core>(pd); |
669 | } else if (mayiuse(avx2)) { |
670 | return new jit_diff_ss_kernel_t<avx2>(pd); |
671 | } else { |
        assert(!"kernel is empty.");
673 | return nullptr; |
674 | } |
675 | } |
676 | |
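// Backward kernel for diff_src. When the forward pass computed the stats,
// the gradient through mean and variance is included:
//   diff_src = (diff_dst * scale
//           - (dd_scale + dd_scale_x * (src - mean) * inv_sqrtvar) / C)
//           * inv_sqrtvar,
// where dd_scale and dd_scale_x are the row-wise sums of diff_dst * scale
// and diff_dst * scale * (src - mean) * inv_sqrtvar, respectively.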
677 | template <cpu_isa_t isa> |
678 | struct jit_diff_data_base_kernel_t : diff_data_kernel_t, public jit_generator { |
679 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_lnorm_diff_data_kernel_t); |
680 | |
681 | void operator()(const void *src, const void *diff_dst, void *diff_src, |
682 | const float *ss, const float *mean, float *const inv_sqrtvar, |
683 | const size_t block_size) const override { |
684 | ker_args_t args; |
685 | args.src = src; |
686 | args.diff_dst = diff_dst; |
687 | args.diff_src = diff_src; |
688 | args.ss = ss; |
689 | args.mean = mean; |
690 | args.inv_sqrtvar = inv_sqrtvar; |
691 | args.block_size |
692 | = block_size * C_ * types::data_type_size(src_d_.data_type()); |
693 | jit_generator::operator()(&args); |
694 | } |
695 | |
696 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
697 | |
698 | jit_diff_data_base_kernel_t(const layer_normalization_pd_t *pd) |
699 | : diff_data_kernel_t(pd) |
700 | , jit_generator(jit_name()) |
701 | , src_d_(pd_->src_md()) |
702 | , d_dst_d_(pd_->diff_dst_md()) |
703 | , d_src_d_(pd_->diff_src_md()) |
704 | , simd_w_(vlen / sizeof(float)) |
705 | , C_(pd_->norm_axis()) |
706 | , axis_simd_full_(C_ / simd_w_) |
707 | , axis_simd_tail_(C_ % simd_w_) |
708 | , use_scale_(pd_->use_scale()) |
709 | , use_shift_(pd_->use_shift()) |
710 | , calculate_diff_stats_(!pd_->stats_are_src()) { |
711 | |
712 | io::io_conf_t io_conf; |
713 | io::io_tail_conf_t io_tail_conf(simd_w_, axis_simd_tail_, |
714 | tail_opmask_idx, vmm_tail_mask.getIdx(), reg_tmp); |
715 | io::io_emu_bf16_conf_t io_bf16_conf(bf16_emu_zmm_1_idx, |
716 | bf16_emu_zmm_2_idx, bf16_emu_zmm_3_idx, reg_tmp, |
717 | bf16_emu_zmm_4_idx); |
718 | const auto io_isa = get_io_isa(isa, |
719 | utils::one_of(f16, src_d_.data_type(), d_dst_d_.data_type(), |
720 | d_src_d_.data_type())); |
721 | io_ = io::jit_io_multi_dt_helper_t<Vmm>(this, io_isa, |
722 | {src_d_.data_type(), d_dst_d_.data_type(), d_src_d_.data_type(), |
723 | f32 /* stats */}, |
724 | io_conf, io_tail_conf, io_bf16_conf); |
725 | } |
726 | |
727 | protected: |
728 | static constexpr int unroll_factor_ = 4; |
729 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
730 | const AddressFrame &vmmword |
731 | = (isa == sse41) ? xword : (isa == avx2) ? yword : zword; |
732 | const int vlen = cpu_isa_traits<isa>::vlen; |
733 | |
734 | struct ker_args_t { |
735 | const void *src; |
736 | const void *diff_dst; |
737 | void *diff_src; |
738 | const float *ss; |
739 | const float *mean; |
740 | const float *inv_sqrtvar; |
741 | size_t block_size; |
742 | }; |
743 | |
744 | io::jit_io_multi_dt_helper_t<Vmm> io_; |
745 | const memory_desc_wrapper src_d_, d_dst_d_, d_src_d_; |
746 | const size_t simd_w_; |
747 | const dim_t C_; |
748 | const dim_t axis_simd_full_; |
749 | const dim_t axis_simd_tail_; |
750 | const bool use_scale_; |
751 | const bool use_shift_; |
752 | const bool calculate_diff_stats_; |
753 | |
754 | const Reg64 reg_param = abi_param1; |
755 | const Reg64 reg_src = rdx; |
756 | const Reg64 reg_diff_dst = rax; |
757 | const Reg64 reg_diff_src = r14; |
758 | const Reg64 reg_mean = rbx; |
759 | const Reg64 reg_inv_sqrtvar = r13; |
760 | const Reg64 reg_scale = r8; |
761 | const Reg64 reg_tmp = r11; |
762 | const Reg64 reg_dd_scale = r10; |
763 | const Reg64 reg_dd_scale_x = r12; |
764 | const Reg64 reg_block_end = r9; |
765 | |
766 | const Vmm vmm_tail_mask = Vmm(0); |
767 | const Vmm vmm_C = Vmm(7); |
768 | const Vmm vmm_scale = Vmm(8); |
769 | const Xmm xmm_tmp = Xmm(9); |
770 | const Vmm vmm_tmp = Vmm(9); |
771 | const Vmm vmm_inv_sqrtvar = Vmm(10); |
772 | const Vmm vmm_dsrc = Vmm(11); |
773 | const Vmm vmm_dd_scale_x = Vmm(12); |
774 | const Vmm vmm_dd_scale = Vmm(13); |
775 | const Vmm vmm_src = Vmm(14); |
776 | const Vmm vmm_mean = Vmm(15); |
777 | |
778 | const int bf16_emu_zmm_1_idx = 28; |
779 | const int bf16_emu_zmm_2_idx = 29; |
780 | const int bf16_emu_zmm_3_idx = 30; |
781 | const int bf16_emu_zmm_4_idx = 31; |
782 | const int tail_opmask_idx = 1; |
783 | |
784 | Address src_ptr(size_t offt = 0) { |
785 | return vmmword[reg_src + offt * src_d_.data_type_size()]; |
786 | } |
787 | |
788 | Address d_dst_ptr(size_t offt = 0) { |
789 | return vmmword[reg_diff_dst + offt * d_dst_d_.data_type_size()]; |
790 | } |
791 | |
792 | Address d_src_ptr(size_t offt = 0) { |
793 | return vmmword[reg_diff_src + offt * d_src_d_.data_type_size()]; |
794 | } |
795 | |
796 | Address scale_ptr(size_t offt = 0) { |
797 | return vmmword[reg_scale + offt * sizeof(float)]; |
798 | } |
799 | |
800 | virtual void reduce(Vmm vmm_src, Vmm vmm_tmp) = 0; |
801 | |
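    // Accumulates one SIMD vector of the row-wise sums dd_scale and
    // dd_scale_x (see the formula above the class).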
802 | void compute_dd_scales(size_t offt_elems, bool tail = false) { |
803 | Vmm vmm_ddst = vmm_dsrc; |
804 | io_[d_dst_d_.data_type()]->load(d_dst_ptr(offt_elems), vmm_ddst, tail); |
805 | if (use_scale_) { |
806 | io_[f32]->load(scale_ptr(offt_elems), vmm_scale, tail); |
807 | uni_vmulps(vmm_ddst, vmm_ddst, vmm_scale); |
808 | } |
809 | io_[src_d_.data_type()]->load(src_ptr(offt_elems), vmm_src, tail); |
810 | |
811 | uni_vaddps(vmm_dd_scale, vmm_dd_scale, vmm_ddst); |
812 | uni_vsubps(vmm_src, vmm_src, vmm_mean); |
813 | uni_vfmadd231ps(vmm_dd_scale_x, vmm_ddst, vmm_src); |
814 | }; |
815 | |
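    // Computes one SIMD vector of diff_src; the mean/variance gradient term
    // is skipped when the stats were provided by the user.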
816 | void compute_diff_src(size_t offt_elems, bool tail = false) { |
817 | Vmm vmm_ddst = vmm_dsrc; |
818 | io_[d_dst_d_.data_type()]->load(d_dst_ptr(offt_elems), vmm_ddst, tail); |
819 | if (use_scale_) { |
820 | io_[f32]->load(scale_ptr(offt_elems), vmm_scale, tail); |
821 | uni_vmulps(vmm_dsrc, vmm_dsrc, vmm_scale); |
822 | } |
823 | if (calculate_diff_stats_) { |
824 | io_[src_d_.data_type()]->load(src_ptr(offt_elems), vmm_src, tail); |
825 | uni_vsubps(vmm_src, vmm_src, vmm_mean); |
826 | uni_vmulps(vmm_src, vmm_src, vmm_inv_sqrtvar); |
827 | uni_vfmadd213ps(vmm_src, vmm_dd_scale_x, vmm_dd_scale); |
828 | uni_vdivps(vmm_src, vmm_src, vmm_C); |
829 | uni_vsubps(vmm_dsrc, vmm_dsrc, vmm_src); |
830 | } |
831 | uni_vmulps(vmm_dsrc, vmm_dsrc, vmm_inv_sqrtvar); |
832 | io_[d_src_d_.data_type()]->store(vmm_dsrc, d_src_ptr(offt_elems), tail); |
833 | }; |
834 | |
835 | void generate() override { |
836 | const size_t c_src_size |
837 | = C_ * types::data_type_size(src_d_.data_type()); |
838 | const size_t c_ddst_size |
839 | = C_ * types::data_type_size(d_dst_d_.data_type()); |
840 | const size_t c_dsrc_size |
841 | = C_ * types::data_type_size(d_src_d_.data_type()); |
842 | static const size_t float_size = types::data_type_size(f32); |
843 | |
844 | preamble(); |
845 | |
846 | io_.init_bf16(); |
847 | if (axis_simd_tail_) io_.prepare_tail_mask(); |
848 | |
849 | #define PARAM_OFF(x) offsetof(ker_args_t, x) |
850 | mov(reg_src, ptr[reg_param + PARAM_OFF(src)]); |
851 | mov(reg_diff_dst, ptr[reg_param + PARAM_OFF(diff_dst)]); |
852 | mov(reg_diff_src, ptr[reg_param + PARAM_OFF(diff_src)]); |
853 | mov(reg_scale, ptr[reg_param + PARAM_OFF(ss)]); |
854 | |
855 | if (calculate_diff_stats_) |
856 | mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]); |
857 | mov(reg_inv_sqrtvar, ptr[reg_param + PARAM_OFF(inv_sqrtvar)]); |
858 | mov(reg_block_end, ptr[reg_param + PARAM_OFF(block_size)]); |
859 | #undef PARAM_OFF |
860 | |
861 | mov(reg_tmp, float2int(C_)); |
862 | uni_vmovq(xmm_tmp, reg_tmp); |
863 | uni_vbroadcastss(vmm_C, xmm_tmp); |
864 | |
865 | // add block_start to block_size to define block_end |
866 | add(reg_block_end, reg_src); |
867 | |
868 | Label unroll_loop, end; |
869 | L(unroll_loop); |
870 | { |
871 | cmp(reg_block_end, reg_src); |
872 | jle(end, T_NEAR); |
873 | |
874 | uni_vmovss(xmm_tmp, dword[reg_inv_sqrtvar]); |
875 | uni_vbroadcastss(vmm_inv_sqrtvar, xmm_tmp); |
876 | |
877 | if (calculate_diff_stats_) { |
878 | uni_vmovss(xmm_tmp, dword[reg_mean]); |
879 | uni_vbroadcastss(vmm_mean, xmm_tmp); |
880 | |
881 | uni_vpxor(vmm_dd_scale, vmm_dd_scale, vmm_dd_scale); |
882 | uni_vpxor(vmm_dd_scale_x, vmm_dd_scale_x, vmm_dd_scale_x); |
883 | |
884 | for (int i = 0; i < axis_simd_full_; i++) |
885 | compute_dd_scales(i * simd_w_); |
886 | if (axis_simd_tail_) |
887 | compute_dd_scales(axis_simd_full_ * simd_w_, true); |
888 | |
889 | reduce(vmm_dd_scale, vmm_tmp); |
890 | reduce(vmm_dd_scale_x, vmm_tmp); |
891 | uni_vmulps(vmm_dd_scale_x, vmm_dd_scale_x, vmm_inv_sqrtvar); |
892 | } |
893 | |
894 | for (int i = 0; i < axis_simd_full_; i++) |
895 | compute_diff_src(i * simd_w_); |
896 | if (axis_simd_tail_) |
897 | compute_diff_src(axis_simd_full_ * simd_w_, true); |
898 | |
899 | add(reg_src, c_src_size); |
900 | add(reg_diff_dst, c_ddst_size); |
901 | add(reg_diff_src, c_dsrc_size); |
902 | if (calculate_diff_stats_) add(reg_mean, float_size); |
903 | add(reg_inv_sqrtvar, float_size); |
904 | jmp(unroll_loop); |
905 | } |
906 | L(end); |
907 | |
908 | postamble(); |
909 | } |
910 | }; |
911 | |
912 | template <cpu_isa_t isa> |
913 | struct jit_diff_data_kernel_t; |
914 | |
915 | template <> |
916 | struct jit_diff_data_kernel_t<avx512_core> |
917 | : public jit_diff_data_base_kernel_t<avx512_core> { |
918 | |
919 | using jit_diff_data_base_kernel_t::jit_diff_data_base_kernel_t; |
920 | |
921 | void reduce(Vmm vmm_src, Vmm vmm_tmp) override { |
922 | vshuff32x4(vmm_tmp, vmm_src, vmm_src, 0x4E); // 256-bit shuffle |
923 | vaddps(vmm_src, vmm_src, vmm_tmp); |
924 | vshuff32x4(vmm_tmp, vmm_src, vmm_src, 0xB1); // 128/256-bit shuffle |
925 | vaddps(vmm_src, vmm_src, vmm_tmp); |
926 | vshufps(vmm_tmp, vmm_src, vmm_src, 0x4E); // 64/128-bit shuffle |
927 | vaddps(vmm_src, vmm_src, vmm_tmp); |
928 | vshufps(vmm_tmp, vmm_src, vmm_src, 0xB1); // 32/64-bit shuffle |
929 | vaddps(vmm_src, vmm_src, vmm_tmp); |
930 | } |
931 | }; |
932 | |
933 | template <> |
934 | struct jit_diff_data_kernel_t<avx2> : public jit_diff_data_base_kernel_t<avx2> { |
935 | |
936 | using jit_diff_data_base_kernel_t::jit_diff_data_base_kernel_t; |
937 | |
938 | void reduce(Vmm vmm_src, Vmm vmm_tmp) override { |
939 | vperm2f128(vmm_tmp, vmm_src, vmm_src, 0x1); // 128/256-bit shuffle |
940 | vaddps(vmm_src, vmm_src, vmm_tmp); |
941 | vshufps(vmm_tmp, vmm_src, vmm_src, 0x4E); // 64/128-bit shuffle |
942 | vaddps(vmm_src, vmm_src, vmm_tmp); |
943 | vshufps(vmm_tmp, vmm_src, vmm_src, 0xB1); // 32/64-bit shuffle |
944 | vaddps(vmm_src, vmm_src, vmm_tmp); |
945 | } |
946 | }; |
947 | |
948 | diff_data_kernel_t *diff_data_kernel_t::create( |
949 | const layer_normalization_pd_t *pd) { |
950 | if (mayiuse(avx512_core)) { |
951 | return new jit_diff_data_kernel_t<avx512_core>(pd); |
952 | } else if (mayiuse(avx2)) { |
953 | return new jit_diff_data_kernel_t<avx2>(pd); |
954 | } else { |
        assert(!"kernel is empty.");
956 | return nullptr; |
957 | } |
958 | } |
959 | |
960 | status_t jit_uni_layer_normalization_fwd_t::execute_forward( |
961 | const exec_ctx_t &ctx) const { |
962 | auto scratchpad = ctx.get_scratchpad_grantor(); |
963 | const auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
964 | auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST); |
965 | |
966 | auto scale = CTX_IN_MEM(const float *, DNNL_ARG_SCALE); |
967 | auto shift = CTX_IN_MEM(const float *, DNNL_ARG_SHIFT); |
968 | |
969 | float *mean, *variance; |
970 | if (pd()->use_tmp_stats()) { |
971 | mean = scratchpad.template get<float>(key_lnorm_tmp_mean); |
972 | variance = scratchpad.template get<float>(key_lnorm_tmp_var); |
973 | } else { |
974 | mean = pd()->stats_are_src() |
975 | ? const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN)) |
976 | : CTX_OUT_MEM(float *, DNNL_ARG_MEAN); |
977 | variance = pd()->stats_are_src() |
978 | ? const_cast<float *>( |
979 | CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE)) |
980 | : CTX_OUT_MEM(float *, DNNL_ARG_VARIANCE); |
981 | } |
982 | |
983 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
984 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
985 | |
986 | const memory_desc_wrapper src_d(pd()->src_md()); |
987 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
988 | |
989 | const dim_t N = pd()->across_axis(); |
990 | const dim_t C_padded = src_d.padded_dims()[pd()->ndims() - 1]; |
991 | |
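    // Split the outer (across) axis between threads; each thread normalizes
    // a contiguous block of rows with a single kernel invocation.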
992 | parallel(0, [&](const int ithr, const int nthr) { |
993 | dim_t N_start = 0, N_end = 0; |
994 | balance211(N, nthr, ithr, N_start, N_end); |
995 | const char *const __restrict src_ptr |
996 | = reinterpret_cast<const char *>(src) |
997 | + N_start * C_padded * src_d.data_type_size(); |
998 | char *const __restrict dst_ptr = reinterpret_cast<char *>(dst) |
999 | + N_start * C_padded * dst_d.data_type_size(); |
1000 | const int block_size = N_end - N_start; |
1001 | (*stat_and_data_kernel_)(src_ptr, dst_ptr, scale, shift, &mean[N_start], |
1002 | &variance[N_start], src_scales, dst_scales, block_size); |
1003 | }); |
1004 | return status::success; |
1005 | } |
1006 | |
1007 | status_t jit_uni_layer_normalization_bwd_t::execute_backward( |
1008 | const exec_ctx_t &ctx) const { |
1009 | status_t status = status::success; |
1010 | |
1011 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1012 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
1013 | auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST); |
1014 | auto scale = CTX_IN_MEM(float *, DNNL_ARG_SCALE); |
1015 | auto diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status); |
1016 | |
1017 | auto diff_scale = CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DIFF_SCALE, status); |
1018 | CHECK(status); |
1019 | auto diff_shift = CTX_OUT_CLEAN_MEM(float *, DNNL_ARG_DIFF_SHIFT, status); |
1020 | CHECK(status); |
1021 | |
1022 | const float *mean, *variance; |
1023 | if (pd()->use_tmp_stats()) { |
1024 | mean = scratchpad.template get<float>(key_lnorm_tmp_mean); |
1025 | variance = scratchpad.template get<float>(key_lnorm_tmp_var); |
1026 | } else { |
1027 | mean = CTX_IN_MEM(const float *, DNNL_ARG_MEAN); |
1028 | variance = CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE); |
1029 | } |
1030 | |
1031 | float *const inv_sqrtvar |
1032 | = scratchpad.template get<float>(key_lnorm_inv_sqrtvar); |
1033 | |
1034 | const memory_desc_wrapper src_d(pd()->src_md()); |
1035 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
1036 | const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); |
1037 | |
1038 | const dim_t N = pd()->across_axis(); |
1039 | const dim_t C = pd()->norm_axis(); |
1040 | const dim_t C_padded = src_d.padded_dims()[pd()->ndims() - 1]; |
1041 | |
1042 | float *reduce = scratchpad.template get<float>(key_lnorm_reduction); |
1043 | if (diff_scale == nullptr) |
1044 | diff_scale = scratchpad.template get<float>(key_lnorm_tmp_diff_ss); |
    if (diff_shift == nullptr)
        diff_shift = scratchpad.template get<float>(key_lnorm_tmp_diff_ss);
1048 | |
1049 | const int max_nthr = pd()->nthr_; |
1050 | |
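    // Pass 1: each thread accumulates partial diff_scale/diff_shift for its
    // rows into a private slice of the `reduce` scratchpad; inv_sqrtvar is
    // filled as a side effect.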
1051 | parallel(max_nthr, [&](int ithr, int nthr) { |
1052 | dim_t N_start = 0, N_end = 0; |
1053 | balance211(N, nthr, ithr, N_start, N_end); |
1054 | const int block_size = N_end - N_start; |
1055 | const char *const __restrict src_ptr |
1056 | = reinterpret_cast<const char *>(src) |
1057 | + N_start * C_padded * src_d.data_type_size(); |
1058 | const char *const __restrict diff_dst_ptr |
1059 | = reinterpret_cast<const char *>(diff_dst) |
1060 | + N_start * C_padded * diff_dst_d.data_type_size(); |
1061 | |
1062 | float *my_diff_gamma = reduce + C * ithr; |
1063 | float *my_diff_beta = reduce + C * nthr + C * ithr; |
1064 | for (dim_t c = 0; c < C; c++) { |
1065 | my_diff_gamma[c] = 0.; |
1066 | my_diff_beta[c] = 0.; |
1067 | } |
1068 | (*diff_ss_kernel_)(src_ptr, diff_dst_ptr, my_diff_gamma, my_diff_beta, |
1069 | &mean[N_start], &variance[N_start], &inv_sqrtvar[N_start], |
1070 | block_size); |
1071 | }); |
1072 | |
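    // Reduce the per-thread partial sums into the final diff_scale and
    // diff_shift.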
1073 | parallel_nd(C, [&](dim_t c) { |
1074 | float diff_gamma = 0, diff_beta = 0; |
1075 | for (dim_t n = 0; n < max_nthr; n++) { |
1076 | diff_gamma += reduce[C * n + c]; |
1077 | diff_beta += reduce[C * max_nthr + C * n + c]; |
1078 | } |
1079 | diff_scale[c] = diff_gamma; |
1080 | diff_shift[c] = diff_beta; |
1081 | }); |
1082 | |
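    // Pass 2: compute diff_src for each thread's block of rows.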
1083 | parallel(max_nthr, [&](int ithr, int nthr) { |
1084 | dim_t N_start = 0, N_end = 0; |
1085 | balance211(N, nthr, ithr, N_start, N_end); |
1086 | const int block_size = N_end - N_start; |
1087 | const char *const __restrict src_ptr |
1088 | = reinterpret_cast<const char *>(src) |
1089 | + N_start * C_padded * src_d.data_type_size(); |
1090 | const char *const __restrict diff_dst_ptr |
1091 | = reinterpret_cast<const char *>(diff_dst) |
1092 | + N_start * C_padded * diff_dst_d.data_type_size(); |
1093 | char *const __restrict diff_src_ptr = reinterpret_cast<char *>(diff_src) |
1094 | + N_start * C_padded * diff_src_d.data_type_size(); |
1095 | |
1096 | (*diff_data_kernel_)(src_ptr, diff_dst_ptr, diff_src_ptr, scale, |
1097 | &mean[N_start], &inv_sqrtvar[N_start], block_size); |
1098 | }); |
1099 | return status::success; |
1100 | } |
1101 | |
1102 | } // namespace x64 |
1103 | } // namespace cpu |
1104 | } // namespace impl |
1105 | } // namespace dnnl |
1106 | |