1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <array>
18#include <cmath>
19#include "common/c_types_map.hpp"
20#include "common/nstl.hpp"
21#include "common/utils.hpp"
22#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
23#include "cpu/x64/lrn/jit_uni_lrn_kernel.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
30using namespace dnnl::impl::format_tag;
31
// IRB_LOOP(statement) emits `statement` once per register block handled by
// the current within_body() call. Inside `statement` two names are available:
//   irb     - register-block index, used to pick vsum[irb], vdst[irb], ...
//   irb_off - byte offset of that block's pixel (irb * single_pixel_offset_)
// With a single block (reg_block == 1) the register set is chosen by
// this->reg_block_idx_ % vsum.size() and irb_off is 0; otherwise blocks
// 0 .. reg_block-1 are emitted in order.
// The expansion refers to `reg_block` and `vsum` at the use site, so both
// must be in scope there. It expands to a bare if/else, so use it only as a
// standalone statement (never as the body of an unbraced `if` with an `else`).
#define IRB_LOOP(statement) \
    if (1 == reg_block) { \
        const int irb_off = 0; \
        const int irb = this->reg_block_idx_ % vsum.size(); \
        statement; \
        MAYBE_UNUSED(irb_off); \
    } else { \
        for (int irb = 0; irb < reg_block; irb++) { \
            const int irb_off = irb * this->single_pixel_offset_; \
            statement; \
            MAYBE_UNUSED(irb_off); \
        } \
    }
45
46using namespace Xbyak;
47
48template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
49 cpu_isa_t isa, data_type_t d_type>
50jit_uni_lrn_kernel_t<Derived<isa, d_type>>::jit_uni_lrn_kernel_t(
51 void *code_ptr, size_t code_size, const char *name)
52 : jit_generator(name, code_ptr, code_size, true, isa)
53 , emulate_bfloat_(isa == avx512_core
54 && d_type == dnnl::impl::data_type::bf16
55 && !mayiuse(avx512_core_bf16))
56 , bf16_emu_(
57 emulate_bfloat_ ? utils::make_unique<bf16_emulation_t>(this,
58 bf16_emu_reserv_1_, bf16_emu_reserv_2_,
59 bf16_emu_reserv_3_, bf16_emu_scratch_, bf16_emu_reserv_4_)
60 : nullptr) {}
61
62template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
63 cpu_isa_t isa, data_type_t d_type>
64jit_uni_lrn_kernel_t<Derived<isa, d_type>>::jit_uni_lrn_kernel_t(
65 const within_config_t &config, void *code_ptr, size_t code_size,
66 const char *name)
67 : jit_uni_lrn_kernel_t(code_ptr, code_size, name) {
68 if (config.dat_tag == nhwc)
69 single_pixel_offset_
70 = config.C * sizeof(typename prec_traits<d_type>::type);
71}
72
// Destructor, defaulted out of line in this translation unit.
// NOTE(review): presumably kept out of the class body so that destroying
// bf16_emu_ (a smart pointer built with utils::make_unique) happens where
// bf16_emulation_t is a complete type -- confirm against the header.
template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
        cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_kernel_t<Derived<isa, d_type>>::~jit_uni_lrn_kernel_t() = default;
76
77template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
78 cpu_isa_t isa, data_type_t d_type>
79void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::within_loop(
80 const within_config_t &config, int max_reg_blocks, prop_kind_t pk) {
81 const auto derived_ptr = static_cast<Derived<isa, d_type> *>(this);
82
83 const int lower_bound = (config.size - 1) / 2,
84 upper_bound = config.size - lower_bound - 1;
85
86 int pixel_count = 0;
87
88 for (int i = 0; i < lower_bound; ++i) {
89 pixel_count = 0;
90 for (int j = 0; j < lower_bound; ++j)
91 derived_ptr->within_body(-i, upper_bound, -j, upper_bound, config.W,
92 pk, 1, pixel_count++ * this->single_pixel_offset_);
93 derived_ptr->move_data_pointers(pixel_count, pk);
94
95 within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks, -i,
96 upper_bound, -lower_bound, upper_bound, config.W, pk);
97
98 pixel_count = 0;
99 for (int j = config.W - upper_bound; j < config.W; ++j)
100 derived_ptr->within_body(-i, upper_bound, -lower_bound,
101 config.W - 1 - j, config.W, pk, 1,
102 pixel_count++ * this->single_pixel_offset_);
103 derived_ptr->move_data_pointers(pixel_count, pk);
104 }
105
106 this->mov(h_, config.H - config.size + 1);
107 Label lrn_loop_h;
108 this->L(lrn_loop_h);
109 pixel_count = 0;
110 for (int j = 0; j < lower_bound; ++j)
111 derived_ptr->within_body(-lower_bound, upper_bound, -j, upper_bound,
112 config.W, pk, 1, pixel_count++ * this->single_pixel_offset_);
113 derived_ptr->move_data_pointers(pixel_count, pk);
114
115 within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks,
116 -lower_bound, upper_bound, -lower_bound, upper_bound, config.W, pk);
117
118 pixel_count = 0;
119 for (int j = config.W - upper_bound; j < config.W; ++j)
120 derived_ptr->within_body(-lower_bound, upper_bound, -lower_bound,
121 config.W - 1 - j, config.W, pk, 1,
122 pixel_count++ * this->single_pixel_offset_);
123 derived_ptr->move_data_pointers(pixel_count, pk);
124
125 this->dec(h_);
126 this->cmp(h_, 0);
127 this->jne(lrn_loop_h, this->T_NEAR);
128
129 for (int i = config.H - upper_bound; i < config.H; ++i) {
130 pixel_count = 0;
131 for (int j = 0; j < lower_bound; ++j)
132 derived_ptr->within_body(-lower_bound, config.H - 1 - i, -j,
133 upper_bound, config.W, pk, 1,
134 pixel_count++ * this->single_pixel_offset_);
135 derived_ptr->move_data_pointers(pixel_count, pk);
136
137 within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks,
138 -lower_bound, config.H - 1 - i, -lower_bound, upper_bound,
139 config.W, pk);
140
141 pixel_count = 0;
142 for (int j = config.W - upper_bound; j < config.W; ++j)
143 derived_ptr->within_body(-lower_bound, config.H - 1 - i,
144 -lower_bound, config.W - 1 - j, config.W, pk, 1,
145 pixel_count++ * this->single_pixel_offset_);
146 derived_ptr->move_data_pointers(pixel_count, pk);
147 }
148}
149
150template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
151 cpu_isa_t isa, data_type_t d_type>
152void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::within_body_reg_blocked(
153 int loop_count, int max_reg_blocks, int hoff, int Hoff, int woff,
154 int Woff, int stride, prop_kind_t pk) {
155
156 const auto derived_ptr = static_cast<Derived<isa, d_type> *>(this);
157 Label reg_block_compute_loop;
158
159 const auto res = std::div(loop_count, max_reg_blocks);
160 if (res.quot) {
161 this->mov(this->w_, res.quot);
162 this->L(reg_block_compute_loop);
163 derived_ptr->within_body(
164 hoff, Hoff, woff, Woff, stride, pk, max_reg_blocks, 0);
165 derived_ptr->move_data_pointers(max_reg_blocks, pk);
166 this->dec(this->w_);
167 this->cmp(this->w_, 0);
168 this->jne(reg_block_compute_loop, this->T_NEAR);
169 }
170 if (res.rem) {
171 derived_ptr->within_body(
172 hoff, Hoff, woff, Woff, stride, pk, res.rem, 0);
173 derived_ptr->move_data_pointers(res.rem, pk);
174 }
175}
176
177template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
178 cpu_isa_t isa, data_type_t d_type>
179void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::load_data(
180 const Vmm &reg, const Xbyak::Address &p) {
181 this->uni_vmovups(reg, p);
182}
183
// Emits a bf16 -> f32 widening load. Each 16-bit element at `p` is
// zero-extended into a 32-bit lane and then moved into the upper half of
// that lane, which is the f32 bit pattern of the same bf16 value.
template <typename Gen, typename Reg, typename Addr>
void load_bf16_data(Gen generator, const Reg &reg, const Addr &p) {
    constexpr int bf16_to_f32_shift = 16;
    generator->vpmovzxwd(reg, p);
    generator->vpslld(reg, reg, bf16_to_f32_shift);
}
189
190template <>
191void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_core,
192 dnnl::impl::data_type::bf16>>::load_data(const Vmm &reg,
193 const Xbyak::Address &p) {
194 load_bf16_data(this, reg, p);
195}
196
197template <>
198void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_core,
199 dnnl::impl::data_type::bf16>>::load_data(const Vmm &reg,
200 const Xbyak::Address &p) {
201 load_bf16_data(this, reg, p);
202}
203
204template <>
205void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_core_fp16,
206 dnnl::impl::data_type::f16>>::load_data(const Vmm &reg,
207 const Xbyak::Address &p) {
208 vcvtph2ps(reg, p);
209}
210
211template <>
212void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_core_fp16,
213 dnnl::impl::data_type::f16>>::load_data(const Vmm &reg,
214 const Xbyak::Address &p) {
215 vcvtph2ps(reg, p);
216}
217
218template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
219 cpu_isa_t isa, data_type_t d_type>
220void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::store_data(
221 const Xbyak::Address &addr, const Vmm &reg) {
222 this->uni_vmovups(addr, reg);
223}
224
225template <typename Gen, typename Bf16Emu>
226void store_bf16_data(
227 Gen generator, Bf16Emu emu, const Xbyak::Address &addr, const Zmm &zr) {
228 const Ymm yr = Ymm(zr.getIdx());
229 if (mayiuse(avx512_core_bf16))
230 generator->vcvtneps2bf16(yr, zr);
231 else
232 emu->vcvtneps2bf16(yr, zr);
233 generator->vmovdqu16(addr, yr);
234}
235
236template <>
237void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_core,
238 dnnl::impl::data_type::bf16>>::store_data(const Xbyak::Address &addr,
239 const Zmm &zr) {
240 store_bf16_data(this, bf16_emu_.get(), addr, zr);
241}
242
243template <>
244void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_core,
245 dnnl::impl::data_type::bf16>>::store_data(const Xbyak::Address &addr,
246 const Zmm &zr) {
247 store_bf16_data(this, bf16_emu_.get(), addr, zr);
248}
249
250template <>
251void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_core_fp16,
252 dnnl::impl::data_type::f16>>::store_data(const Xbyak::Address &addr,
253 const Zmm &zr) {
254 vcvtps2ph(addr, zr, _op_mxcsr);
255}
256
257template <>
258void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_core_fp16,
259 dnnl::impl::data_type::f16>>::store_data(const Xbyak::Address &addr,
260 const Zmm &zr) {
261 vcvtps2ph(addr, zr, _op_mxcsr);
262}
263
264template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
265 cpu_isa_t isa, data_type_t d_type>
266void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::load_constant(
267 float constant, const Vmm &v_constant, const Xbyak::Xmm &x_constant) {
268 this->mov(this->imm_addr64_, float2int(constant));
269 this->uni_vmovq(x_constant, this->imm_addr64_);
270 this->vbroadcastss(v_constant, x_constant);
271}
272
273template <>
274void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<sse41,
275 dnnl::impl::data_type::f32>>::load_constant(float constant,
276 const Vmm &v_constant, const Xbyak::Xmm &x_constant) {
277 this->mov(this->imm_addr64_, float2int(constant));
278 this->uni_vmovq(x_constant, this->imm_addr64_);
279 this->shufps(x_constant, x_constant, 0);
280}
281
282//////////////////////////////////////////////////////////////////////////////
283// forward kernel
284template <cpu_isa_t isa, data_type_t d_type>
285void jit_uni_lrn_fwd_kernel_t<isa, d_type>::within_body(int hoff, int Hoff,
286 int woff, int Woff, int stride, prop_kind_t pk, const int reg_block,
287 int pixel_offset) {
288
289 static const std::array<Vmm, 3> vsum {{Vmm(2), Vmm(11), Vmm(20)}};
290 static const std::array<Vmm, 3> vsum2 {{Vmm(3), Vmm(12), Vmm(21)}};
291 static const std::array<Vmm, 3> vdst {{Vmm(4), Vmm(13), Vmm(22)}};
292 static const std::array<std::array<Vmm, 6u>, 3u> vtmp {
293 {{{Vmm(5), Vmm(6), Vmm(7), Vmm(8), Vmm(9), Vmm(14)}},
294 {{Vmm(18), Vmm(15), Vmm(16), Vmm(17), Vmm(29), Vmm(30)}},
295 {{Vmm(23), Vmm(24), Vmm(25), Vmm(26), Vmm(28), Vmm(31)}}}};
296 static const std::array<Vmm, 3> vscratch = {{Vmm(10), Vmm(19), Vmm(27)}};
297 static const std::size_t used_tmp_regs
298 = this->emulate_bfloat_ ? vtmp[0].size() - 2 : vtmp[0].size();
299
300 IRB_LOOP(this->uni_vxorps(vsum[irb], vsum[irb], vsum[irb]));
301 for (int i = hoff; i <= Hoff; ++i) {
302 for (int j = woff; j <= Woff; ++j) {
303 if (i == 0 && j == 0) {
304 IRB_LOOP(this->load_data(
305 vdst[irb], this->ptr[src_ + pixel_offset + irb_off]));
306 IRB_LOOP(this->vfmadd231ps(vsum[irb], vdst[irb], vdst[irb]));
307 } else {
308 const auto idx = this->tempIdx_ % used_tmp_regs;
309 IRB_LOOP(this->load_data(vtmp[irb][idx],
310 this->ptr[(src_ + pixel_offset + irb_off)
311 + (i * stride + j)
312 * this->single_pixel_offset_]));
313 IRB_LOOP(this->vfmadd231ps(
314 vsum[irb], vtmp[irb][idx], vtmp[irb][idx]));
315 ++(this->tempIdx_);
316 }
317 }
318 }
319
320 this->tempIdx_ = this->tempIdx_ % used_tmp_regs;
321
322 IRB_LOOP(this->vfmadd132ps(
323 vsum[irb], vk_, valpha_)); // ysum <- ysum*valpha_+yk_
324 IRB_LOOP(this->vmovaps(vscratch[irb], vsum[irb]));
325
326 IRB_LOOP(this->vmulps(vsum2[irb], vsum[irb], vsum[irb]));
327 IRB_LOOP(this->vmulps(
328 vsum[irb], vsum[irb], vsum2[irb])); // ysum = (ysum*valpha_+yk_)^3;
329 IRB_LOOP(this->vsqrtps(vsum[irb], vsum[irb]));
330 IRB_LOOP(this->vsqrtps(
331 vsum[irb], vsum[irb])); // ysum = (ysum*valpha_+yk_)^0.75
332 IRB_LOOP(this->vdivps(
333 vdst[irb], vdst[irb], vsum[irb])); // ydst <- ydst / ysum
334
335 if (pk_ != prop_kind::forward_inference) {
336 IRB_LOOP(this->store_data(
337 this->ptr[scratch_ + pixel_offset + irb_off], vsum[irb]));
338 IRB_LOOP(this->vdivps(vscratch[irb], vdst[irb], vscratch[irb]));
339 IRB_LOOP(this->store_data(
340 this->ptr[bwd_intermediate_res_ + pixel_offset + irb_off],
341 vscratch[irb]));
342 }
343
344 IRB_LOOP(this->store_data(
345 this->ptr[dst_ + pixel_offset + irb_off], vdst[irb]));
346
347 if (is_superset(isa, avx512_core))
348 this->reg_block_idx_ = (this->reg_block_idx_ % vsum.size()) + 1;
349}
350
// sse41 f32 specialization of the within-channel body. It handles exactly one
// pixel per call and splits that pixel's 8 f32 values into two xmm halves
// (`_lo` = values 0-3, `_hi` = values 4-7):
//   base = k + alpha * sum_{window} src^2
//   dst  = src / base^0.75
// The window covers rows [hoff, Hoff] and columns [woff, Woff] relative to
// the pixel; `stride` is the row length in pixels.
// NOTE(review): `reg_block` and `pk` are ignored (the member pk_ is checked
// instead). Unlike the generic within_body, this stores `base` itself (not
// base^0.75) to scratch_ and writes nothing to bwd_intermediate_res_ --
// confirm that no backward kernel consumes those on sse41.
template <>
void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::within_body(
        int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk,
        int reg_block, int pixel_offset) {

    const Xbyak::Xmm &xtmp_lo = this->xmm2;
    const Xbyak::Xmm &xtmp_hi = this->xmm3;
    const Xbyak::Xmm &xsum_lo = this->xmm4;
    const Xbyak::Xmm &xsum_hi = this->xmm5;
    const Xbyak::Xmm &xdst_lo = this->xmm6;
    const Xbyak::Xmm &xdst_hi = this->xmm7;
    const Xbyak::Xmm &xsum2_lo = this->xmm8;
    const Xbyak::Xmm &xsum2_hi = this->xmm9;

    // Accumulate squares over the clipped window.
    xorps(xsum_lo, xsum_lo);
    xorps(xsum_hi, xsum_hi);
    for (int i = hoff; i <= Hoff; ++i) {
        for (int j = woff; j <= Woff; ++j) {
            if (i == 0 && j == 0) {
                // Centre pixel. xdst is squared in place here, which is why
                // src is reloaded further down before the division.
                movups(xdst_lo, ptr[src_ + pixel_offset]);
                movups(xdst_hi, ptr[src_ + pixel_offset + 4 * sizeof(float)]);
                mulps(xdst_lo, xdst_lo);
                mulps(xdst_hi, xdst_hi);
                addps(xsum_lo, xdst_lo);
                addps(xsum_hi, xdst_hi);
            } else {
                movups(xtmp_lo,
                        ptr[src_ + pixel_offset
                                + (i * stride + j) * single_pixel_offset_]);
                movups(xtmp_hi,
                        ptr[src_ + pixel_offset
                                + (i * stride + j) * single_pixel_offset_
                                + 4 * sizeof(float)]);
                this->mulps(xtmp_lo, xtmp_lo);
                this->mulps(xtmp_hi, xtmp_hi);
                this->addps(xsum_lo, xtmp_lo);
                this->addps(xsum_hi, xtmp_hi);
            }
        }
    }
    this->mulps(xsum_lo, xalpha_);
    this->mulps(xsum_hi, xalpha_);
    this->addps(xsum_lo, xk_);
    this->addps(xsum_hi, xk_); // xsum <- xsum*xalpha_+xk_
    // When training, save base = xsum*xalpha_+xk_ to the scratch buffer.
    this->movaps(xtmp_lo, xsum_lo);
    this->movaps(xtmp_hi, xsum_hi);
    if (pk_ != prop_kind::forward_inference) {
        this->movups(this->ptr[scratch_ + pixel_offset], xtmp_lo);
        this->movups(this->ptr[scratch_ + pixel_offset + 4 * sizeof(float)],
                xtmp_hi);
    }
    this->movaps(xsum2_lo, xsum_lo);
    this->movaps(xsum2_hi, xsum_hi);
    this->mulps(xsum2_lo, xsum_lo);
    this->mulps(xsum2_hi, xsum_hi);
    this->mulps(xsum_lo, xsum2_lo);
    this->mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha_+xk_)^3;

    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi);
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi); // xsum = (xsum*xalpha_+xk_)^0.75

    // Reload the unsquared centre values.
    this->movups(xdst_lo, this->ptr[src_ + pixel_offset]);
    this->movups(xdst_hi, this->ptr[src_ + pixel_offset + 4 * sizeof(float)]);
    this->divps(xdst_lo, xsum_lo);
    this->divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum

    this->movups(this->ptr[dst_ + pixel_offset], xdst_lo);
    this->movups(this->ptr[dst_ + pixel_offset + 4 * sizeof(float)], xdst_hi);
}
422
423template <cpu_isa_t isa, data_type_t d_type>
424void jit_uni_lrn_fwd_kernel_t<isa, d_type>::move_data_pointers(
425 int pixel_count, prop_kind_t pk) {
426
427 const int pixel_offset = this->single_pixel_offset_ * pixel_count;
428 this->add(src_, pixel_offset);
429 this->add(dst_, pixel_offset);
430 if (pk_ != prop_kind::forward_inference) {
431 this->add(scratch_, pixel_offset);
432 this->add(bwd_intermediate_res_, pixel_offset);
433 }
434}
435
436template <cpu_isa_t isa, data_type_t d_type>
437jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
438 const within_config_t &config, float A, float K, prop_kind_t pk,
439 void *code_ptr, size_t code_size)
440 : Base(config, code_ptr, code_size, jit_name())
441 , config_(lrn_config_t::within_config)
442 , within_config_(config)
443 , alpha_(A)
444 , k_(K)
445 , pk_(pk) {}
446
447template <cpu_isa_t isa, data_type_t d_type>
448void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(
449 const within_config_t &config) {
450 this->preamble();
451 if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();
452
453#define GET_OFF(field) offsetof(jit_args_fwd_t, field)
454 this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]);
455 this->mov(dst_, this->ptr[this->param1 + GET_OFF(dst)]);
456 if (pk_ != prop_kind::forward_inference) {
457 this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]);
458 this->mov(bwd_intermediate_res_,
459 this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]);
460 }
461#undef GET_OFF
462
463 this->load_constant(alpha_, valpha_, xalpha_);
464 this->load_constant(k_, vk_, xk_);
465
466 static const int max_reg_blocks = is_superset(isa, avx512_core) ? 3 : 1;
467 this->within_loop(config, max_reg_blocks, pk_);
468
469 this->postamble();
470}
471
472template <cpu_isa_t isa, data_type_t d_type>
473jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
474 const struct nchw8c_across_t &J, float A, float K, prop_kind_t pk,
475 void *code_ptr, size_t code_size)
476 : Base(code_ptr, code_size, jit_name())
477 , config_(lrn_config_t::nchw8c_across)
478 , nchw8c_across_(J)
479 , alpha_(A)
480 , k_(K)
481 , pk_(pk) {}
482
// Across-channel forward kernel for nChw8c data, written with ymm registers
// (8 x 4-byte values per 32-byte channel block). Each loop iteration handles
// one spatial point: its 8 channels plus 2 neighbouring channels on each
// side, which live in the adjacent channel blocks H * W * 32 bytes away.
//   base = k + alpha * sum_{d=-2..2} src[c+d]^2,   dst = src / base^0.75
// J.version == -1 / +1 means there is no previous / next channel block; the
// missing neighbours are then read as zeros.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nchw8c_across_t &J) {
    const Xbyak::Reg64 &t = this->rsp;
    const Xbyak::Reg64 &hw = this->r9;
    const Xbyak::Xmm &xsrc_prev = this->xmm2;
    const Xbyak::Ymm &ysrc = this->ymm3;
    const Xbyak::Ymm &yc = this->ymm3; // centre term; aliases ysrc
    const Xbyak::Xmm &xsrc_next = this->xmm4;
    const Xbyak::Ymm &ya = this->ymm5;
    const Xbyak::Ymm &yb = this->ymm6;
    const Xbyak::Ymm &yd = this->ymm7;
    const Xbyak::Ymm &ye = this->ymm8;
    const Xbyak::Ymm &ysum = this->ymm9;
    const Xbyak::Ymm &ysum2 = this->ymm10;
    const Xbyak::Ymm &ydst = this->ymm11;
    const Xbyak::Ymm &ybase = this->ymm12;

    this->preamble();
    if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();

    // Kernel arguments by raw byte offset: src (0), dst (8), scratch (16).
    // NOTE(review): assumed to match the jit_args_fwd_t layout that the
    // within-channel generate() reads via offsetof -- confirm.
    this->mov(src_, this->ptr[this->param1 + 0]);
    this->mov(dst_, this->ptr[this->param1 + 8]);
    if (pk_ != prop_kind::forward_inference)
        this->mov(scratch_, this->ptr[this->param1 + 16]);
    // 64-byte stack staging buffer at [t, t + 64):
    //   [t +  0, t + 16): last 4 values of the previous channel block
    //   [t + 16, t + 48): the 8 values of the current block
    //   [t + 48, t + 64): first 4 values of the next channel block
    this->sub(t, 64);
    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->vmovq(xalpha_, this->imm_addr64_);
    this->vbroadcastss(valpha_, xalpha_);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->vmovq(xk_, this->imm_addr64_);
    this->vbroadcastss(yk_, xk_);

    // Missing neighbour blocks are zeroed once and never overwritten.
    if (J.version == -1) {
        this->vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
        this->vmovups(this->ptr[t + 0], xsrc_prev);
    }
    if (J.version == +1) {
        this->vxorps(xsrc_next, xsrc_next, xsrc_next);
        this->vmovups(this->ptr[t + 48], xsrc_next);
    }

    this->mov(hw, J.H * J.W);

    Label lrn_loop;
    this->L(lrn_loop);

    if (J.version != -1)
        this->vmovups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]);
    this->vmovups(ysrc, this->ptr[src_]);
    if (J.version != +1)
        this->vmovups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]);

    if (J.version != -1) this->vmovups(this->ptr[t + 0], xsrc_prev);
    this->vmovups(this->ptr[t + 16], ysrc);
    if (J.version != +1) this->vmovups(this->ptr[t + 48], xsrc_next);

    // Unaligned reloads at byte offsets -8/-4/+4/+8 from the current block
    // give the channel window shifted by -2/-1/+1/+2.
    this->vmovups(ya, this->ptr[t + 16 - 8]);
    this->vmovups(yb, this->ptr[t + 16 - 4]);
    this->vmovups(yd, this->ptr[t + 16 + 4]);
    this->vmovups(ye, this->ptr[t + 16 + 8]);
    this->vmulps(ysum, yc, yc);
    this->vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya
    this->vfmadd231ps(ysum, yb, yb);
    this->vfmadd231ps(ysum, yd, yd);
    this->vfmadd231ps(ysum, ye, ye);
    this->vfmadd132ps(ysum, yk_, valpha_); // ysum <- ysum*valpha_+yk_

    // When training, the scratch buffer receives base (before the power).
    this->vmovaps(ybase, ysum);
    if (pk_ != prop_kind::forward_inference)
        this->vmovups(this->ptr[scratch_], ybase);
    this->vmulps(ysum2, ysum, ysum);
    this->vmulps(ysum, ysum, ysum2); // ysum = ybase^3;
    this->vsqrtps(ysum, ysum);
    this->vsqrtps(ysum, ysum); // ysum = ybase^0.75
    this->vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum
    this->vmovups(this->ptr[dst_], ydst);

    this->add(src_, 32);
    this->add(dst_, 32);
    if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32);
    this->dec(hw);
    this->cmp(hw, 0);
    this->jne(lrn_loop, this->T_NEAR);

    // Release the staging buffer.
    this->add(t, 64);
    this->postamble();
}
571
572template <>
573jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
574 jit_uni_lrn_fwd_kernel_t(const struct nchw8c_across_t &J, float A,
575 float K, prop_kind_t pk, void *code_ptr, size_t code_size)
576 : Base(code_ptr, code_size, jit_name())
577 , config_(lrn_config_t::nchw8c_across)
578 , nchw8c_across_(J)
579 , alpha_(A)
580 , k_(K)
581 , pk_(pk) {}
582
// sse41 f32 specialization of the nChw8c across-channel forward kernel.
// It mirrors the ymm version, but every 8-value block is handled as two xmm
// halves (`_lo` = values 0-3, `_hi` = values 4-7).
//   base = k + alpha * sum_{d=-2..2} src[c+d]^2,   dst = src / base^0.75
// J.version == -1 / +1 means there is no previous / next channel block; the
// missing neighbours are then read as zeros.
template <>
void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate(
        const nchw8c_across_t &J) {

    const Xbyak::Reg64 &t = this->rsp;
    const Xbyak::Reg64 &hw = this->r9;
    const Xbyak::Xmm &xsrc_lo = this->xmm2;
    const Xbyak::Xmm &xsrc_hi = this->xmm3;
    const Xbyak::Xmm &xc_lo = this->xmm4;
    const Xbyak::Xmm &xc_hi = this->xmm5;
    const Xbyak::Xmm &xsum_lo = xc_lo; // the sum accumulates in place of xc
    const Xbyak::Xmm &xsum_hi = xc_hi;
    const Xbyak::Xmm &xsrc_prev = this->xmm6;
    const Xbyak::Xmm &xsrc_next = this->xmm7;
    const Xbyak::Xmm &xa_lo = this->xmm8;
    const Xbyak::Xmm &xa_hi = this->xmm9;
    const Xbyak::Xmm &xb_lo = this->xmm10;
    const Xbyak::Xmm &xb_hi = this->xmm11;
    const Xbyak::Xmm &xd_lo = this->xmm12;
    const Xbyak::Xmm &xd_hi = this->xmm13;
    const Xbyak::Xmm &xe_lo = this->xmm14;
    const Xbyak::Xmm &xe_hi = this->xmm15;
    const Xbyak::Xmm &xbase_lo = this->xmm14; // reuses xe once it is consumed
    const Xbyak::Xmm &xbase_hi = this->xmm15;

    this->preamble();
    if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();

    // Kernel arguments by raw byte offset: src (0), dst (8), scratch (16).
    this->mov(src_, this->ptr[this->param1 + 0]);
    this->mov(dst_, this->ptr[this->param1 + 8]);
    if (pk_ != prop_kind::forward_inference)
        this->mov(scratch_, this->ptr[this->param1 + 16]);
    // 64-byte stack staging buffer at [t, t + 64):
    //   [t +  0, t + 16): last 4 values of the previous channel block
    //   [t + 16, t + 48): the 8 values of the current block
    //   [t + 48, t + 64): first 4 values of the next channel block
    this->sub(t, 64);
    // Broadcast alpha and k into all 4 lanes (movq + shufps with imm 0).
    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->movq(xalpha_, this->imm_addr64_);
    this->shufps(xalpha_, xalpha_, 0);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->movq(xk_, this->imm_addr64_);
    this->shufps(xk_, xk_, 0);

    // Missing neighbour blocks are zeroed once and never overwritten.
    if (J.version == -1) {
        this->xorps(xsrc_prev, xsrc_prev);
        this->movups(this->ptr[t + 0], xsrc_prev);
    }
    if (J.version == +1) {
        this->xorps(xsrc_next, xsrc_next);
        this->movups(this->ptr[t + 48], xsrc_next);
    }

    this->mov(hw, J.H * J.W);
    Label lrn_loop;
    L(lrn_loop);

    if (J.version != -1)
        this->movups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]);
    this->movups(xsrc_lo, this->ptr[src_]);
    this->movups(xsrc_hi, this->ptr[src_ + 4 * sizeof(float)]);
    if (J.version != +1)
        this->movups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]);

    if (J.version != -1) this->movups(this->ptr[t + 0], xsrc_prev);
    this->movups(this->ptr[t + 16], xsrc_lo);
    this->movups(this->ptr[t + 16 + 4 * sizeof(float)], xsrc_hi);
    if (J.version != +1) this->movups(this->ptr[t + 48], xsrc_next);

    // Unaligned reloads at byte offsets -8/-4/+4/+8 give the channel window
    // shifted by -2/-1/+1/+2.
    this->movups(xa_lo, this->ptr[t + 16 - 8]);
    this->movups(xa_hi, this->ptr[t + 16 - 8 + 4 * sizeof(float)]);
    this->movups(xb_lo, this->ptr[t + 16 - 4]);
    this->movups(xb_hi, this->ptr[t + 16 - 4 + 4 * sizeof(float)]);
    this->movups(xd_lo, this->ptr[t + 16 + 4]);
    this->movups(xd_hi, this->ptr[t + 16 + 4 + 4 * sizeof(float)]);
    this->movups(xe_lo, this->ptr[t + 16 + 8]);
    this->movups(xe_hi, this->ptr[t + 16 + 8 + 4 * sizeof(float)]);
    // xsum aliases xc, so this leaves xsum = xc * xc.
    this->movaps(xc_lo, xsrc_lo);
    this->movaps(xc_hi, xsrc_hi);
    this->mulps(xsum_lo, xc_lo);
    this->mulps(xsum_hi, xc_hi);
    this->mulps(xa_lo, xa_lo);
    this->mulps(xa_hi, xa_hi);
    this->addps(xsum_lo, xa_lo);
    this->addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa
    this->mulps(xb_lo, xb_lo);
    this->mulps(xb_hi, xb_hi);
    this->addps(xsum_lo, xb_lo);
    this->addps(xsum_hi, xb_hi);
    this->mulps(xd_lo, xd_lo);
    this->mulps(xd_hi, xd_hi);
    this->addps(xsum_lo, xd_lo);
    this->addps(xsum_hi, xd_hi);
    this->mulps(xe_lo, xe_lo);
    this->mulps(xe_hi, xe_hi);
    this->addps(xsum_lo, xe_lo);
    this->addps(xsum_hi, xe_hi);

    this->mulps(xsum_lo, xalpha_);
    this->mulps(xsum_hi, xalpha_);
    this->addps(xsum_lo, xk_);
    this->addps(xsum_hi, xk_); // xsum <- xsum*xalpha_+xk_

    // When training, the scratch buffer receives base (before the power).
    this->movaps(xbase_lo, xsum_lo);
    this->movaps(xbase_hi, xsum_hi);
    if (pk_ != prop_kind::forward_inference) {
        this->movups(this->ptr[scratch_], xbase_lo);
        this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
    }
    this->mulps(xsum_lo, xsum_lo);
    this->mulps(xsum_hi, xsum_hi);
    this->mulps(xsum_lo, xbase_lo);
    this->mulps(xsum_hi, xbase_hi); // xsum = xbase^3;
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi);
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75
    this->divps(xsrc_lo, xsum_lo);
    this->divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum
    this->movups(this->ptr[dst_], xsrc_lo);
    this->movups(this->ptr[dst_ + 4 * sizeof(float)], xsrc_hi);

    this->add(src_, 32);
    this->add(dst_, 32);
    if (pk_ != prop_kind::forward_inference) add(scratch_, 32);
    this->dec(hw);
    this->cmp(hw, 0);
    this->jne(lrn_loop, this->T_NEAR);

    // Release the staging buffer.
    this->add(t, 64);
    this->postamble();
}
712
713template <cpu_isa_t isa, data_type_t d_type>
714jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
715 const struct nhwc_across_t &J, float A, float K, prop_kind_t pk,
716 void *code_ptr, size_t code_size)
717 : Base(code_ptr, code_size, jit_name())
718 , config_(lrn_config_t::nhwc_across)
719 , nhwc_across_(J)
720 , alpha_(A)
721 , k_(K)
722 , pk_(pk) {}
723
724template <cpu_isa_t isa, data_type_t d_type>
725void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nhwc_across_t &J) {
726 static const uint32_t mask[] = {0, 0, 0x80000000, 0x80000000, 0x80000000,
727 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0, 0};
728
729 const Xbyak::Reg64 &c = this->r9;
730 const Xbyak::Ymm &ya = this->ymm2;
731 const Xbyak::Ymm &yb = this->ymm3;
732 const Xbyak::Ymm &yc = this->ymm4;
733 const Xbyak::Ymm &yd = this->ymm5;
734 const Xbyak::Ymm &ye = this->ymm6;
735 const Xbyak::Ymm &ysum = this->ymm7;
736 const Xbyak::Ymm &ydst = this->ymm8;
737 const Xbyak::Ymm &ybase = this->ymm9;
738 const Xbyak::Ymm &ymask = this->ymm10;
739
740 this->preamble();
741 if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();
742
743 this->mov(src_, this->ptr[this->param1 + 0]);
744 this->mov(dst_, this->ptr[this->param1 + 8]);
745 if (pk_ != prop_kind::forward_inference)
746 this->mov(scratch_, this->ptr[this->param1 + 16]);
747 this->mov(this->imm_addr64_, float2int(this->alpha_));
748 this->vmovq(xalpha_, this->imm_addr64_);
749 this->vbroadcastss(valpha_, xalpha_);
750
751 this->mov(this->imm_addr64_, float2int(this->k_));
752 this->vmovq(xk_, this->imm_addr64_);
753 this->vbroadcastss(yk_, xk_);
754
755 this->vxorps(ysum, ysum, ysum);
756
757 this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[0]));
758 this->vmovups(ymask, this->ptr[this->imm_addr64_]);
759 this->vmaskmovps(ya, ymask, this->ptr[src_ - 8]);
760 this->vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
761
762 this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[1]));
763 this->vmovups(ymask, this->ptr[this->imm_addr64_]);
764 this->vmaskmovps(yb, ymask, this->ptr[src_ - 4]);
765 this->vfmadd231ps(ysum, yb, yb);
766
767 this->mov(c, J.C / 8 - 1);
768 Label lrn_loop;
769 this->L(lrn_loop);
770
771 this->vmovups(yc, this->ptr[src_]);
772 this->vmovups(yd, this->ptr[src_ + 4]);
773 this->vmovups(ye, this->ptr[src_ + 8]);
774 this->vfmadd231ps(ysum, yc, yc);
775 this->vfmadd231ps(ysum, yd, yd);
776 this->vfmadd231ps(ysum, ye, ye);
777
778 this->vmovups(ydst, ysum);
779 this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_
780
781 this->vmovaps(ybase, ydst);
782 if (pk_ != prop_kind::forward_inference)
783 this->vmovups(this->ptr[scratch_], ybase);
784 this->vmulps(ydst, ydst, ydst);
785 this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3;
786 this->vsqrtps(ydst, ydst);
787 this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75
788
789 this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75
790 this->vmovups(this->ptr[dst_], ydst);
791
792 this->vxorps(ysum, ysum, ysum);
793
794 this->add(src_, 32);
795 this->add(dst_, 32);
796 if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32);
797
798 this->vmovups(ya, this->ptr[src_ - 8]);
799 this->vfmadd231ps(ysum, ya, ya);
800 this->vmovups(yb, this->ptr[src_ - 4]);
801 this->vfmadd231ps(ysum, yb, yb);
802
803 this->dec(c);
804 this->cmp(c, 0);
805 this->jne(lrn_loop, this->T_NEAR);
806
807 this->vmovups(yc, this->ptr[src_]);
808 this->vfmadd231ps(ysum, yc, yc);
809
810 this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[2]));
811 this->vmovups(ymask, this->ptr[this->imm_addr64_]);
812 this->vmaskmovps(yd, ymask, this->ptr[src_ + 4]);
813 this->vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
814
815 this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[3]));
816 this->vmovups(ymask, this->ptr[this->imm_addr64_]);
817 this->vmaskmovps(ye, ymask, this->ptr[src_ + 8]);
818 this->vfmadd231ps(ysum, ye, ye);
819
820 this->vmovups(ydst, ysum);
821 this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_
822
823 this->vmovaps(ybase, ydst);
824 if (pk_ != prop_kind::forward_inference)
825 this->vmovups(this->ptr[scratch_], ybase);
826 this->vmulps(ydst, ydst, ydst);
827 this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3;
828 this->vsqrtps(ydst, ydst);
829 this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75
830 this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75
831
832 this->vmovups(this->ptr[dst_], ydst);
833
834 this->postamble();
835}
836
837template <>
838jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
839 jit_uni_lrn_fwd_kernel_t(const struct nhwc_across_t &J, float A,
840 float K, prop_kind_t pk, void *code_ptr, size_t code_size)
841 : Base(code_ptr, code_size, jit_name())
842 , config_(lrn_config_t::nhwc_across)
843 , nhwc_across_(J)
844 , alpha_(A)
845 , k_(K)
846 , pk_(pk) {}
847
848template <>
849void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate(
850 const nhwc_across_t &J) {
851 static uint32_t store[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
852 const Xbyak::Reg64 c = this->r9;
853
854 const Xbyak::Xmm &xdst_lo = this->xmm0;
855 const Xbyak::Xmm &xdst_hi = this->xmm1;
856 const Xbyak::Xmm &xa_lo = this->xmm2;
857 const Xbyak::Xmm &xa_hi = this->xmm3;
858 const Xbyak::Xmm &xb_lo = this->xmm2;
859 const Xbyak::Xmm &xb_hi = this->xmm3;
860 const Xbyak::Xmm &xc_lo = this->xmm4;
861 const Xbyak::Xmm &xc_hi = this->xmm5;
862 const Xbyak::Xmm &xd_lo = this->xmm6;
863 const Xbyak::Xmm &xd_hi = this->xmm7;
864 const Xbyak::Xmm &xe_lo = this->xmm8;
865 const Xbyak::Xmm &xe_hi = this->xmm9;
866 const Xbyak::Xmm &xsum_lo = this->xmm10;
867 const Xbyak::Xmm &xsum_hi = this->xmm11;
868 // unused: xmm12, xmm13;
869 const Xbyak::Xmm &xbase_lo = this->xmm14;
870 const Xbyak::Xmm &xbase_hi = this->xmm15;
871
872 this->preamble();
873 if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();
874
875 this->mov(src_, this->ptr[this->param1 + 0]);
876 this->mov(dst_, this->ptr[this->param1 + 8]);
877 if (pk_ != prop_kind::forward_inference)
878 mov(scratch_, this->ptr[this->param1 + 16]);
879 this->mov(this->imm_addr64_, float2int(this->alpha_));
880 this->movq(xalpha_, this->imm_addr64_);
881 this->shufps(xalpha_, xalpha_, 0);
882
883 this->mov(this->imm_addr64_, float2int(this->k_));
884 this->movq(xk_, this->imm_addr64_);
885 this->shufps(xk_, xk_, 0);
886
887 this->mov(store_addr_, reinterpret_cast<size_t>(&store[0]));
888 this->and_(store_addr_, -15);
889 this->movups(this->ptr[store_addr_], xalpha_);
890 this->movups(this->ptr[store_addr_ + 4 * sizeof(float)], xk_);
891
892 this->xorps(xsum_lo, xsum_lo);
893 this->xorps(xsum_hi, xsum_hi);
894
895 /* load the 2 first blocks of channels
896 * block: | -- low -- | -- hi -- |
897 * C: [c1,c2,c3,c4,c5,c6,c7,c8]
898 * xa_lo << 2 [0,0,c1,c2]
899 * xa_hi [c3,c4,c5,c6]
900 * xb_lo << 1 [0,c1,c2,c3]
901 * xb_hi [c4,c5,c6,c7]
902 * | -- data -- (...)
903 * ^ memory boundary
904 */
905 this->movups(xa_lo, this->ptr[src_]);
906 this->movups(xa_hi, this->ptr[src_ + 2 * sizeof(float)]);
907 this->pslldq(xa_lo, 2 * sizeof(float));
908 this->mulps(xa_lo, xa_lo);
909 this->mulps(xa_hi, xa_hi);
910 this->addps(xsum_lo, xa_lo);
911 this->addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
912
913 this->movups(xb_lo, this->ptr[src_]);
914 this->movups(xb_hi, this->ptr[src_ + 3 * sizeof(float)]);
915 this->pslldq(xb_lo, 1 * sizeof(float));
916 this->mulps(xb_lo, xb_lo);
917 this->mulps(xb_hi, xb_hi);
918 this->addps(xsum_lo, xb_lo);
919 this->addps(xsum_hi, xb_hi);
920
921 this->mov(c, J.C / 8 - 1);
922 Label lrn_loop;
923 this->L(lrn_loop);
924
925 this->movups(xc_lo, this->ptr[src_]);
926 this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
927 this->movups(xd_lo, this->ptr[src_ + 4]);
928 this->movups(xd_hi, this->ptr[src_ + 4 + 4 * sizeof(float)]);
929 this->movups(xe_lo, this->ptr[src_ + 8]);
930 this->movups(xe_hi, this->ptr[src_ + 8 + 4 * sizeof(float)]);
931 this->mulps(xc_lo, xc_lo);
932 this->mulps(xc_hi, xc_hi);
933 this->addps(xsum_lo, xc_lo);
934 this->addps(xsum_hi, xc_hi);
935 this->mulps(xd_lo, xd_lo);
936 this->mulps(xd_hi, xd_hi);
937 this->addps(xsum_lo, xd_lo);
938 this->addps(xsum_hi, xd_hi);
939 this->mulps(xe_lo, xe_lo);
940 this->mulps(xe_hi, xe_hi);
941 this->addps(xsum_lo, xe_lo);
942 this->addps(xsum_hi, xe_hi);
943
944 this->movaps(xdst_lo, xsum_lo);
945 this->movaps(xdst_hi, xsum_hi);
946 // xdst <- xsum*xalpha_+xk_
947 this->mulps(xdst_lo, this->ptr[store_addr_]);
948 this->mulps(xdst_hi, this->ptr[store_addr_]);
949 this->addps(xdst_lo, this->ptr[store_addr_ + 4 * sizeof(float)]);
950 this->addps(xdst_hi, this->ptr[store_addr_ + 4 * sizeof(float)]);
951
952 this->movaps(xbase_lo, xdst_lo);
953 this->movaps(xbase_hi, xdst_hi);
954 if (pk_ != prop_kind::forward_inference) {
955 this->movups(this->ptr[scratch_], xbase_lo);
956 this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
957 }
958 this->mulps(xdst_lo, xdst_lo);
959 this->mulps(xdst_hi, xdst_hi);
960 this->mulps(xdst_lo, xbase_lo);
961 this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3;
962 this->sqrtps(xdst_lo, xdst_lo);
963 this->sqrtps(xdst_hi, xdst_hi);
964 this->sqrtps(xdst_lo, xdst_lo);
965 this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75
966
967 this->movups(xc_lo, this->ptr[src_]);
968 this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
969 this->divps(xc_lo, xdst_lo);
970 this->divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75
971 this->movups(this->ptr[dst_], xc_lo);
972 this->movups(this->ptr[dst_ + 4 * sizeof(float)], xc_hi);
973
974 this->xorps(xsum_lo, xsum_lo);
975 this->xorps(xsum_hi, xsum_hi);
976
977 this->add(src_, 32);
978 this->add(dst_, 32);
979 if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32);
980
981 this->movups(xa_lo, this->ptr[src_ - 8]);
982 this->movups(xa_hi, this->ptr[src_ - 8 + 4 * sizeof(float)]);
983 this->mulps(xa_lo, xa_lo);
984 this->mulps(xa_hi, xa_hi);
985 this->addps(xsum_lo, xa_lo);
986 this->addps(xsum_hi, xa_hi);
987 this->movups(xb_lo, this->ptr[src_ - 4]);
988 this->movups(xb_hi, this->ptr[src_ - 4 + 4 * sizeof(float)]);
989 this->mulps(xb_lo, xb_lo);
990 this->mulps(xb_hi, xb_hi);
991 this->addps(xsum_lo, xb_lo);
992 this->addps(xsum_hi, xb_hi);
993
994 this->dec(c);
995 this->cmp(c, 0);
996 this->jne(lrn_loop, this->T_NEAR);
997
998 /* compute last 3 blocks of channels:
999 * block: | -- low -- | -- hi -- |
1000 * C: [c1,c2,c3,c4,c5,c6,c7,c8]
1001 * xc_lo|xc_hi [c1,c2,c3,c4|c5,c6,c7,c8]
1002 * xd_lo [c2,c3,c4,c5]
1003 * xd_hi >> 1 [c6,c7,c8, 0]
1004 * xe_lo [c3,c4,c5,c6]
1005 * xe_hi >> 2 [c7,c8, 0, 0]
1006 * (...) -- data -- | -- illegal reading -- (...)
1007 * ^ memory boundary
1008 */
1009 this->movups(xc_lo, this->ptr[src_]);
1010 this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
1011 this->mulps(xc_lo, xc_lo);
1012 this->mulps(xc_hi, xc_hi);
1013 this->addps(xsum_lo, xc_lo);
1014 this->addps(xsum_hi, xc_hi);
1015
1016 this->movups(xd_lo, this->ptr[src_ + 1 * sizeof(float)]);
1017 this->movups(xd_hi, this->ptr[src_ + 4 * sizeof(float)]);
1018 this->psrldq(xd_hi, 1 * sizeof(float));
1019 this->mulps(xd_lo, xd_lo);
1020 this->mulps(xd_hi, xd_hi);
1021 this->addps(xsum_lo, xd_lo);
1022 this->addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
1023
1024 this->movups(xe_lo, this->ptr[src_ + 2 * sizeof(float)]);
1025 this->movups(xe_hi, this->ptr[src_ + 4 * sizeof(float)]);
1026 this->psrldq(xe_hi, 2 * sizeof(float));
1027 this->mulps(xe_lo, xe_lo);
1028 this->mulps(xe_hi, xe_hi);
1029 this->addps(xsum_lo, xe_lo);
1030 this->addps(xsum_hi, xe_hi);
1031
1032 this->movups(xdst_lo, xsum_lo);
1033 this->movups(xdst_hi, xsum_hi);
1034 // xdst <- xsum*xalpha_+xk_
1035 this->mulps(xdst_lo, this->ptr[store_addr_]);
1036 this->mulps(xdst_hi, this->ptr[store_addr_]);
1037 this->addps(xdst_lo, this->ptr[store_addr_ + 4 * sizeof(float)]);
1038 this->addps(xdst_hi, this->ptr[store_addr_ + 4 * sizeof(float)]);
1039
1040 this->movaps(xbase_lo, xdst_lo);
1041 this->movaps(xbase_hi, xdst_hi);
1042 if (pk_ != prop_kind::forward_inference) {
1043 this->movups(this->ptr[scratch_], xbase_lo);
1044 this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
1045 }
1046 this->mulps(xdst_lo, xdst_lo);
1047 this->mulps(xdst_hi, xdst_hi);
1048 this->mulps(xdst_lo, xbase_lo);
1049 this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3;
1050 this->sqrtps(xdst_lo, xdst_lo);
1051 this->sqrtps(xdst_hi, xdst_hi);
1052 this->sqrtps(xdst_lo, xdst_lo);
1053 this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75
1054 this->movups(xc_lo, this->ptr[src_]);
1055 this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]);
1056 this->divps(xc_lo, xdst_lo);
1057 this->divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75
1058
1059 this->movups(this->ptr[dst_], xc_lo);
1060 this->movups(this->ptr[dst_ + 4 * sizeof(float)], xc_hi);
1061
1062 this->postamble();
1063}
1064
// The ymm-based nchw_body is not used by the SSE4.1/f32 kernel (its nchw
// generate() calls nchw_body_sse41 instead), so this specialization is an
// intentionally empty stub.
template <>
void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::nchw_body(
        int tail, int HW, prop_kind_t pk, Xbyak::Ymm ymask, Xbyak::Ymm ya,
        Xbyak::Ymm yb, Xbyak::Ymm yc, Xbyak::Ymm yd, Xbyak::Ymm ye,
        Xbyak::Ymm ysum) {}
1070
1071template <cpu_isa_t isa, data_type_t d_type>
1072void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_body(int tail, int HW,
1073 prop_kind_t pk, Xbyak::Ymm ymask, Xbyak::Ymm ya, Xbyak::Ymm yb,
1074 Xbyak::Ymm yc, Xbyak::Ymm yd, Xbyak::Ymm ye, Xbyak::Ymm ysum) {
1075 const Xbyak::Ymm &ydst = this->ymm14;
1076 const Xbyak::Ymm &ybase = this->ymm15;
1077
1078 this->vfmadd231ps(ysum, ye, ye);
1079
1080 this->vmovups(ydst, ysum);
1081 this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_
1082
1083 this->vmovaps(ybase, ydst);
1084 if (pk_ != prop_kind::forward_inference) {
1085 if (tail != 0)
1086 this->vmaskmovps(this->ptr[scratch_], ymask, ybase);
1087 else
1088 this->vmovups(this->ptr[scratch_], ybase);
1089 }
1090 this->vmulps(ydst, ydst, ydst);
1091 this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3;
1092 this->vsqrtps(ydst, ydst);
1093 this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75
1094 this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75
1095
1096 if (tail != 0)
1097 this->vmaskmovps(this->ptr[dst_], ymask, ydst);
1098 else
1099 this->vmovups(this->ptr[dst_], ydst);
1100
1101 this->vfnmadd231ps(ysum, ya, ya);
1102 this->vmovups(ya, yb);
1103 this->vmovups(yb, yc);
1104 this->vmovups(yc, yd);
1105 this->vmovups(yd, ye);
1106}
1107
// Tail-store helper used only by the SSE4.1/f32 kernel. For every other
// isa/data-type instantiation it is an intentionally empty stub; the real
// implementation is the sse41/f32 specialization.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_tail_sse41(int tail,
        Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) {}
1111
1112template <>
1113void jit_uni_lrn_fwd_kernel_t<sse41,
1114 dnnl::impl::data_type::f32>::nchw_tail_sse41(int tail,
1115 Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) {
1116 Xbyak::Xmm xmm_tmp = xmm10;
1117 this->movaps(xmm_tmp, xtail_hi);
1118
1119 if (tail > 3) {
1120 /* Store upper-half directly */
1121 this->movups(this->ptr[reg_dst + (tail - 4) * sizeof(float)], xtail_hi);
1122 this->movaps(xmm_tmp, xtail_lo);
1123 tail -= 4;
1124 }
1125 if (tail > 0) {
1126 /* Store on a single-element basis when 'tail' overlaps
1127 * with 'src_' */
1128 this->psrldq(xmm_tmp, (4 - tail) * sizeof(float));
1129 this->movss(this->ptr[reg_dst], xmm_tmp);
1130
1131 for (int i = 1; i < tail; i++) {
1132 this->psrldq(xmm_tmp, sizeof(float));
1133 this->movss(this->ptr[reg_dst + i * sizeof(float)], xmm_tmp);
1134 }
1135 }
1136}
1137
// SSE4.1/f32 counterpart of nchw_body: emits one channel step of the 5-wide
// across-channel window for 8 lanes split into "lo"/"hi" xmm halves.
//
// The window operands live in 16-byte slots at store_addr_, set up by the
// SSE4.1 nchw generate():
//   slot 0 = alpha, 1 = k, 2-3 = xa, 4-5 = xb, 6-7 = xc (the channel being
//   normalized), 8-9 = xd; this function itself writes xe to slots 10-11.
// On entry xsum holds xa^2 + xb^2 + xc^2 + xd^2 and xe_lo/xe_hi hold the
// newest channel. The step:
//   - stores base = k + alpha * sum to scratch_ when pk_ is not
//     forward_inference,
//   - stores xc / base^0.75 to dst_ (via nchw_tail_sse41 when tail != 0),
//   - subtracts xa^2 from xsum and shifts the slots a <- b <- c <- d <- e.
// HW and pk are unused; the member pk_ is consulted instead.
//
// Register reuse depends on instruction order: xbase aliases xa (xmm6/7),
// xtmp aliases xb (xmm8/9), and nchw_tail_sse41 clobbers xmm10 (= xc_lo),
// which is only loaded after the tail stores. On return xe_lo/xe_hi again
// hold the un-squared xe (reloaded from slots 10-11).
// All slot accesses use aligned movaps / memory-operand mulps and addps, so
// store_addr_ must be 16-byte aligned.
template <>
void jit_uni_lrn_fwd_kernel_t<sse41,
        dnnl::impl::data_type::f32>::nchw_body_sse41(int tail, int HW,
        prop_kind_t pk, Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo,
        Xbyak::Xmm xsum_hi) {
    const Xbyak::Xmm &xdst_lo = this->xmm0;
    const Xbyak::Xmm &xdst_hi = this->xmm1;
    const Xbyak::Xmm &xbase_lo = this->xmm6;
    const Xbyak::Xmm &xbase_hi = this->xmm7;
    const Xbyak::Xmm &xtmp_lo = this->xmm8;
    const Xbyak::Xmm &xtmp_hi = this->xmm9;
    const Xbyak::Xmm &xa_lo = this->xmm6;
    const Xbyak::Xmm &xa_hi = this->xmm7;
    const Xbyak::Xmm &xb_lo = this->xmm8;
    const Xbyak::Xmm &xb_hi = this->xmm9;
    const Xbyak::Xmm &xc_lo = this->xmm10;
    const Xbyak::Xmm &xc_hi = this->xmm11;
    const Xbyak::Xmm &xd_lo = this->xmm12;
    const Xbyak::Xmm &xd_hi = this->xmm13;

    // store xe (un-squared) so it can become xd after this step
    this->movaps(this->ptr[store_addr_ + 10 * 4 * sizeof(float)], xe_lo);
    this->movaps(this->ptr[store_addr_ + 11 * 4 * sizeof(float)], xe_hi);

    // complete the window: xsum += xe^2
    this->mulps(xe_lo, xe_lo);
    this->mulps(xe_hi, xe_hi);
    this->addps(xsum_lo, xe_lo);
    this->addps(xsum_hi, xe_hi);

    // xdst <- xsum*xalpha_+xk_
    this->movaps(xdst_lo, xsum_lo);
    this->movaps(xdst_hi, xsum_hi);
    this->mulps(xdst_lo, this->ptr[store_addr_ + 0 * 4 * sizeof(float)]);
    this->mulps(xdst_hi, this->ptr[store_addr_ + 0 * 4 * sizeof(float)]);
    this->addps(xdst_lo, this->ptr[store_addr_ + 1 * 4 * sizeof(float)]);
    this->addps(xdst_hi, this->ptr[store_addr_ + 1 * 4 * sizeof(float)]);

    this->movaps(xbase_lo, xdst_lo);
    this->movaps(xbase_hi, xdst_hi);
    if (pk_ != prop_kind::forward_inference) {
        if (tail != 0) {
            // partial store; clobbers xmm10 (xc_lo is reloaded later)
            nchw_tail_sse41(tail, scratch_, xbase_lo, xbase_hi);
        } else {
            this->movups(this->ptr[scratch_], xbase_lo);
            this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
        }
    }
    this->mulps(xdst_lo, xdst_lo);
    this->mulps(xdst_hi, xdst_hi);
    this->mulps(xdst_lo, xbase_lo);
    this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3;
    this->sqrtps(xdst_lo, xdst_lo);
    this->sqrtps(xdst_hi, xdst_hi);
    this->sqrtps(xdst_lo, xdst_lo);
    this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75
    // numerator is xc (slots 6-7), the channel being normalized
    this->movaps(xtmp_lo, this->ptr[store_addr_ + 6 * 4 * sizeof(float)]);
    this->movaps(xtmp_hi, this->ptr[store_addr_ + 7 * 4 * sizeof(float)]);
    this->divps(xtmp_lo, xdst_lo);
    this->divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75
    this->movaps(xdst_lo, xtmp_lo);
    this->movaps(xdst_hi, xtmp_hi);

    if (tail != 0) {
        nchw_tail_sse41(tail, dst_, xdst_lo, xdst_hi);
    } else {
        this->movups(this->ptr[dst_], xdst_lo);
        this->movups(this->ptr[dst_ + 4 * sizeof(float)], xdst_hi);
    }

    // drop the oldest channel from the window: xsum -= xa^2
    this->movaps(xa_lo, this->ptr[store_addr_ + 2 * 4 * sizeof(float)]);
    this->movaps(xa_hi, this->ptr[store_addr_ + 3 * 4 * sizeof(float)]);
    this->mulps(xa_lo, xa_lo);
    this->mulps(xa_hi, xa_hi);
    this->subps(xsum_lo, xa_lo);
    this->subps(xsum_hi, xa_hi);

    // xa <- xb
    this->movaps(xb_lo, this->ptr[store_addr_ + 4 * 4 * sizeof(float)]);
    this->movaps(xb_hi, this->ptr[store_addr_ + 5 * 4 * sizeof(float)]);
    this->movaps(this->ptr[store_addr_ + 2 * 4 * sizeof(float)], xb_lo);
    this->movaps(this->ptr[store_addr_ + 3 * 4 * sizeof(float)], xb_hi);

    // xb <- xc
    this->movaps(xc_lo, this->ptr[store_addr_ + 6 * 4 * sizeof(float)]);
    this->movaps(xc_hi, this->ptr[store_addr_ + 7 * 4 * sizeof(float)]);
    this->movaps(this->ptr[store_addr_ + 4 * 4 * sizeof(float)], xc_lo);
    this->movaps(this->ptr[store_addr_ + 5 * 4 * sizeof(float)], xc_hi);

    // xc <- xd
    this->movaps(xd_lo, this->ptr[store_addr_ + 8 * 4 * sizeof(float)]);
    this->movaps(xd_hi, this->ptr[store_addr_ + 9 * 4 * sizeof(float)]);
    this->movaps(this->ptr[store_addr_ + 6 * 4 * sizeof(float)], xd_lo);
    this->movaps(this->ptr[store_addr_ + 7 * 4 * sizeof(float)], xd_hi);

    // xd <- xe
    this->movaps(xe_lo, this->ptr[store_addr_ + 10 * 4 * sizeof(float)]);
    this->movaps(xe_hi, this->ptr[store_addr_ + 11 * 4 * sizeof(float)]);
    this->movaps(this->ptr[store_addr_ + 8 * 4 * sizeof(float)], xe_lo);
    this->movaps(this->ptr[store_addr_ + 9 * 4 * sizeof(float)], xe_hi);
}
1238
// SSE4.1 channel-step helper. It is only implemented for the sse41/f32
// specialization; for every other isa/data-type instantiation it is an
// intentionally empty stub.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_body_sse41(int tail, int HW,
        prop_kind_t pk, Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo,
        Xbyak::Xmm xsum_hi) {}
1243
1244template <cpu_isa_t isa, data_type_t d_type>
1245jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
1246 const nchw_across_t &J, float A, float K, prop_kind_t pk,
1247 void *code_ptr, size_t code_size)
1248 : Base(code_ptr, code_size, jit_name())
1249 , config_(lrn_config_t::nchw_across)
1250 , nchw_across_(J)
1251 , alpha_(A)
1252 , k_(K)
1253 , pk_(pk) {}
1254
1255template <cpu_isa_t isa, data_type_t d_type>
1256void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nchw_across_t &J) {
1257 static const uint32_t mask[]
1258 = {0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
1259 0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0};
1260 const Xbyak::Reg64 &c = this->r10;
1261 const Xbyak::Ymm &ymask = this->ymm2;
1262 const Xbyak::Ymm &ye = this->ymm3;
1263 const Xbyak::Ymm &ya = this->ymm4;
1264 const Xbyak::Ymm &yb = this->ymm5;
1265 const Xbyak::Ymm &yc = this->ymm6;
1266 const Xbyak::Ymm &yd = this->ymm7;
1267 const Xbyak::Ymm &ysum = this->ymm8;
1268
1269 this->preamble();
1270 if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();
1271
1272 if (J.tail != 0) {
1273 this->mov(
1274 this->imm_addr64_, reinterpret_cast<size_t>(&mask[7 - J.tail]));
1275 this->vmovups(ymask, this->ptr[this->imm_addr64_]);
1276 }
1277 this->mov(this->imm_addr64_, float2int(this->alpha_));
1278 this->vmovq(xalpha_, this->imm_addr64_);
1279 this->vbroadcastss(valpha_, xalpha_);
1280
1281 this->mov(this->imm_addr64_, float2int(this->k_));
1282 this->vmovq(xk_, this->imm_addr64_);
1283 this->vbroadcastss(yk_, xk_);
1284
1285 this->mov(src_, this->ptr[this->param1 + 0]);
1286 this->mov(dst_, this->ptr[this->param1 + 8]);
1287 if (pk_ != prop_kind::forward_inference)
1288 this->mov(scratch_, this->ptr[this->param1 + 16]);
1289
1290 this->vxorps(ya, ya, ya);
1291 this->vxorps(yb, yb, yb);
1292 if (J.tail != 0)
1293 this->vmaskmovps(yc, ymask, this->ptr[src_ + J.HW * 0]);
1294 else
1295 this->vmovups(yc, this->ptr[src_ + J.HW * 0]);
1296 if (J.tail != 0)
1297 this->vmaskmovps(yd, ymask, this->ptr[src_ + J.HW * 4]);
1298 else
1299 this->vmovups(yd, this->ptr[src_ + J.HW * 4]);
1300
1301 this->vxorps(ysum, ysum, ysum);
1302 this->vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
1303 this->vfmadd231ps(ysum, yd, yd);
1304
1305 this->mov(c, J.C - 2);
1306 Label lrn_loop;
1307 this->L(lrn_loop);
1308
1309 if (J.tail != 0)
1310 this->vmaskmovps(ye, ymask, this->ptr[src_ + J.HW * 8]);
1311 else
1312 this->vmovups(ye, this->ptr[src_ + J.HW * 8]);
1313
1314 nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum);
1315
1316 this->add(src_, J.HW * 4);
1317 this->add(dst_, J.HW * 4);
1318 if (pk_ != prop_kind::forward_inference) this->add(scratch_, J.HW * 4);
1319 this->dec(c);
1320 this->cmp(c, 0);
1321 this->jne(lrn_loop, this->T_NEAR);
1322
1323 this->vxorps(ye, ye, ye);
1324
1325 nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum);
1326 this->add(src_, J.HW * 4);
1327 this->add(dst_, J.HW * 4);
1328 if (pk_ != prop_kind::forward_inference) this->add(scratch_, J.HW * 4);
1329
1330 nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum);
1331
1332 this->postamble();
1333}
1334
// Defaulted destructor, defined out of line in this translation unit.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::~jit_uni_lrn_fwd_kernel_t() = default;
1337
1338template <>
1339jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
1340 jit_uni_lrn_fwd_kernel_t(const nchw_across_t &J, float A, float K,
1341 prop_kind_t pk, void *code_ptr, size_t code_size)
1342 : Base(code_ptr, code_size, jit_name())
1343 , config_(lrn_config_t::nchw_across)
1344 , nchw_across_(J)
1345 , alpha_(A)
1346 , k_(K)
1347 , pk_(pk) {}
1348
1349template <>
1350void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate(
1351 const nchw_across_t &J) {
1352
1353 /* Load from within the memory boundary of 'src_' and apply a zero-mask to
1354 * the 'x_hi' register:
1355 * block: src_ |tail = 3
1356 * src_: [x,x,x,x|a,b,c]
1357 * x_hi: [x,a,b,c]
1358 * mask: [0,1,1,1]
1359 * (...) -- data -- | -- illegal reading -- (...)
1360 * ^ memory boundary
1361 *
1362 * 'x_lo' is loaded with the elements between 'src_' and 'x_hi' when
1363 * tail.size is between [5:7]. The register is then left-shifted to
1364 * clear the overlapping elements with 'x_hi'.
1365 * block: - src_ - | tail = 7
1366 * src_: (...) [x,|a,b,c,d,e,f,g]
1367 * x_hi [d,e,f,g]
1368 * x_lo [a,b,c,d]
1369 * x_lo >> 1: [0,a,b,c]
1370 * (...) -- data -- | -- illegal reading -- (...)
1371 * ^ memory boundary
1372 *
1373 * - seg-fault happens if read occurs anywhere outside the
1374 * memory boundary.
1375 * */
1376 static const uint32_t mask[]
1377 = {0, 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff};
1378 assert(J.HW > 3);
1379
1380 const Xbyak::Reg64 &c = r10;
1381
1382 // unused: xmm2
1383 const Xbyak::Xmm &xmask_hi = this->xmm3;
1384 const Xbyak::Xmm &xsum_lo = this->xmm4;
1385 const Xbyak::Xmm &xsum_hi = this->xmm5;
1386 const Xbyak::Xmm &xa_lo = this->xmm6;
1387 const Xbyak::Xmm &xa_hi = this->xmm7;
1388 const Xbyak::Xmm &xb_lo = this->xmm8;
1389 const Xbyak::Xmm &xb_hi = this->xmm9;
1390 const Xbyak::Xmm &xc_lo = this->xmm10;
1391 const Xbyak::Xmm &xc_hi = this->xmm11;
1392 const Xbyak::Xmm &xd_lo = this->xmm12;
1393 const Xbyak::Xmm &xd_hi = this->xmm13;
1394 const Xbyak::Xmm &xe_lo = this->xmm14;
1395 const Xbyak::Xmm &xe_hi = this->xmm15;
1396
1397 const int vlen = cpu_isa_traits<sse41>::vlen / sizeof(float);
1398
1399 bool compute_tail = J.tail != 0;
1400 bool load_lo = J.tail == 0 || J.tail > 4;
1401
1402 size_t h_offset = vlen;
1403 size_t l_shift = 0;
1404
1405 this->preamble();
1406 if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();
1407
1408 this->mov(src_, this->ptr[this->param1 + 0]);
1409 this->mov(dst_, this->ptr[this->param1 + 8]);
1410 if (pk_ != prop_kind::forward_inference)
1411 this->mov(scratch_, this->ptr[this->param1 + 16]);
1412
1413 this->sub(rsp, stack_space_needed_);
1414 this->mov(store_addr_, rsp);
1415 this->and_(store_addr_, -15);
1416
1417 this->mov(this->imm_addr64_, float2int(this->alpha_));
1418 this->movq(xalpha_, this->imm_addr64_);
1419 this->shufps(xalpha_, xalpha_, 0);
1420
1421 this->mov(this->imm_addr64_, float2int(this->k_));
1422 this->movq(xk_, this->imm_addr64_);
1423 this->shufps(xk_, xk_, 0);
1424
1425 // put alpha_ and k_ into store (free up regs)
1426 this->movaps(this->ptr[store_addr_ + 0 * 4 * sizeof(float)], xalpha_);
1427 this->movaps(this->ptr[store_addr_ + 1 * 4 * sizeof(float)], xk_);
1428
1429 if (compute_tail) {
1430 assert(J.tail > 0 && J.tail < 2 * vlen);
1431 h_offset = J.tail - vlen;
1432 l_shift = nstl::min(2 * vlen - J.tail, vlen);
1433
1434 /* if 'tail' is between [1:3], need to zero-mask for underflow */
1435 size_t m_off = nstl::min(J.tail - 1, 3);
1436 this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[m_off]));
1437 this->movups(xmask_hi, this->ptr[this->imm_addr64_]);
1438 }
1439 // init xa, xb
1440 this->xorps(xa_lo, xa_lo);
1441 this->xorps(xa_hi, xa_hi);
1442 this->xorps(xb_lo, xb_lo);
1443 this->xorps(xb_hi, xb_hi);
1444
1445 // read xc, xd
1446 if (load_lo) this->movups(xc_lo, this->ptr[src_ + J.HW * 0]);
1447 this->movups(xc_hi, this->ptr[src_ + J.HW * 0 + h_offset * sizeof(float)]);
1448 if (compute_tail) {
1449 this->pslldq(xc_lo, l_shift * sizeof(float));
1450 this->andps(xc_hi, xmask_hi);
1451 }
1452
1453 if (load_lo) this->movups(xd_lo, this->ptr[src_ + J.HW * 4]);
1454 this->movups(xd_hi, this->ptr[src_ + J.HW * 4 + h_offset * sizeof(float)]);
1455 if (compute_tail) {
1456 this->pslldq(xd_lo, l_shift * sizeof(float));
1457 this->andps(xd_hi, xmask_hi);
1458 }
1459
1460 // put xa, xb, xc, xd into store to free-up regs
1461 this->movaps(this->ptr[store_addr_ + 2 * 4 * sizeof(float)], xa_lo);
1462 this->movaps(this->ptr[store_addr_ + 3 * 4 * sizeof(float)], xa_hi);
1463 this->movaps(this->ptr[store_addr_ + 4 * 4 * sizeof(float)], xb_lo);
1464 this->movaps(this->ptr[store_addr_ + 5 * 4 * sizeof(float)], xb_hi);
1465 this->movaps(this->ptr[store_addr_ + 6 * 4 * sizeof(float)], xc_lo);
1466 this->movaps(this->ptr[store_addr_ + 7 * 4 * sizeof(float)], xc_hi);
1467 this->movaps(this->ptr[store_addr_ + 8 * 4 * sizeof(float)], xd_lo);
1468 this->movaps(this->ptr[store_addr_ + 9 * 4 * sizeof(float)], xd_hi);
1469
1470 this->xorps(xsum_lo, xsum_lo);
1471 this->xorps(xsum_hi, xsum_hi);
1472 this->mulps(xc_lo, xc_lo);
1473 this->mulps(xc_hi, xc_hi);
1474 this->addps(xsum_lo, xc_lo);
1475 this->addps(xsum_hi, xc_hi);
1476 this->mulps(xd_lo, xd_lo);
1477 this->mulps(xd_hi, xd_hi);
1478 this->addps(xsum_lo, xd_lo);
1479 this->addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
1480
1481 this->mov(c, J.C - 2);
1482 Label lrn_loop;
1483 this->L(lrn_loop);
1484
1485 if (load_lo) this->movups(xe_lo, this->ptr[src_ + J.HW * 8]);
1486 this->movups(xe_hi, this->ptr[src_ + J.HW * 8 + h_offset * sizeof(float)]);
1487 if (compute_tail) {
1488 this->pslldq(xe_lo, l_shift * sizeof(float));
1489 this->andps(xe_hi, xmask_hi);
1490 }
1491
1492 nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi);
1493
1494 this->add(src_, J.HW * 4);
1495 this->add(dst_, J.HW * 4);
1496 if (pk_ != prop_kind::forward_inference) add(scratch_, J.HW * 4);
1497 this->dec(c);
1498 this->cmp(c, 0);
1499 this->jne(lrn_loop, this->T_NEAR);
1500
1501 this->xorps(xe_lo, xe_lo);
1502 this->xorps(xe_hi, xe_hi);
1503
1504 nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi);
1505 this->add(src_, J.HW * 4);
1506 this->add(dst_, J.HW * 4);
1507 if (pk_ != prop_kind::forward_inference) add(scratch_, J.HW * 4);
1508
1509 nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi);
1510
1511 this->add(rsp, stack_space_needed_);
1512
1513 this->postamble();
1514}
1515
1516//////////////////////////////////////////////////////////////////////////////
1517// backward kernel
1518template <cpu_isa_t isa, data_type_t d_type>
1519jit_uni_lrn_bwd_kernel_t<isa, d_type>::jit_uni_lrn_bwd_kernel_t(
1520 const nchw8c_across_t &J, float A, float B, int use_h_parallel,
1521 void *code_ptr, size_t code_size)
1522 : Base(code_ptr, code_size, jit_name())
1523 , config_(lrn_config_t::nchw8c_across)
1524 , nchw8c_across_(J)
1525 , nalphabeta_(-2 * A * B)
1526 , use_h_parallelizm_(use_h_parallel) {}
1527
// Emits the backward LRN kernel for the nchw8c across-channel case. Each loop
// iteration handles one pixel of one 8-channel block (32 bytes); the
// neighbouring blocks are J.H * J.W * 32 bytes before/after. The kernel
// computes, for the current 8 channels:
//   ydiffsrc = diff_dst / ws^0.75
//            + (nalphabeta_ * src) * sum_{window of 5 channels} q
//   with q = diff_dst * src / ws^1.75 per channel, and ws = scratch.
// The window spans into the upper 4 channels of the previous block and the
// lower 4 channels of the next block. Those terms are zero when the block is
// the first/last one (J.version -1/+1, both for -2) or the only one
// (J.version 3).
// bwd_intermediate_res_ is loaded from the arguments but not used here.
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_bwd_kernel_t<isa, d_type>::generate(const nchw8c_across_t &J) {

    const Xbyak::Reg64 &t = this->rsp;
    const Xbyak::Reg64 &hw = this->r10;
    const Xbyak::Xmm &xsrc_prev = this->xmm1;
    const Xbyak::Xmm &xws_prev = this->xmm2;
    const Xbyak::Xmm &xdiffdst_prev = this->xmm3;
    const Xbyak::Ymm &ysrc = this->ymm4;
    const Xbyak::Ymm &yws = this->ymm5;
    const Xbyak::Ymm &ydiffdst = this->ymm6;
    const Xbyak::Xmm &xsrc_next = this->xmm7;
    const Xbyak::Xmm &xws_next = this->xmm8;
    const Xbyak::Xmm &xdiffdst_next = this->xmm9;
    // ya and xa are the same register (ymm10 / its low half xmm10)
    const Xbyak::Ymm &ya = this->ymm10;
    const Xbyak::Xmm &xa = this->xmm10;
    const Xbyak::Ymm &yb = this->ymm11;
    const Xbyak::Ymm &yd = this->ymm12;
    const Xbyak::Ymm &ye = this->ymm13;
    const Xbyak::Ymm &ysum = this->ymm14;
    const Xbyak::Ymm &ydiffsrc = this->ymm15;

    this->preamble();
    if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();

#define GET_OFF(field) offsetof(jit_args_bwd_t, field)
    this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]);
    this->mov(diffdst_, this->ptr[this->param1 + GET_OFF(diff_dst)]);
    this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]);
    this->mov(bwd_intermediate_res_,
            this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]);
    this->mov(diffsrc_, this->ptr[this->param1 + GET_OFF(diff_src)]);
#undef GET_OFF

    // 64-byte stack window of 16 floats: [t+0,16) = 4 "prev" terms,
    // [t+16,48) = 8 current terms, [t+48,64) = 4 "next" terms.
    this->sub(t, 64);
    this->mov(this->imm_addr64_, float2int(this->nalphabeta_));
    this->vmovq(xnalphabeta_, this->imm_addr64_);
    this->vbroadcastss(vnalphabeta_, xnalphabeta_);

    bool is_single = J.version == 3;
    bool is_first = J.version == -1 || J.version == -2;
    bool is_last = J.version == +1 || J.version == -2;

    // Missing neighbour blocks contribute zeros; those slots are written
    // once here and never overwritten inside the loop.
    if (is_first || is_single) {
        this->vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
        this->vmovups(this->ptr[t + 0], xsrc_prev);
    }
    if (is_last || is_single) {
        this->vxorps(xsrc_next, xsrc_next, xsrc_next);
        this->vmovups(this->ptr[t + 48], xsrc_next);
    }
    // one iteration per pixel: a single row when parallelized over H
    this->mov(hw, this->use_h_parallelizm_ ? J.W : J.H * J.W);
    Label lrn_loop;
    this->L(lrn_loop);
    {
        if (!is_first && !is_single) {
            // upper 4 channels (+16 bytes) of the previous block:
            // xdiffdst_prev = diff_dst * src / ws^1.75
            this->vmovups(xws_prev, this->ptr[scratch_ - J.H * J.W * 32 + 16]);
            this->vmovups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]);
            this->vmovups(
                    xdiffdst_prev, this->ptr[diffdst_ - J.H * J.W * 32 + 16]);
            this->vmulps(xa, xws_prev, xws_prev);
            this->vmulps(xa, xa, xws_prev);
            this->vsqrtps(xa, xa);
            this->vsqrtps(xa, xa);
            this->vmulps(xa, xa, xws_prev);
            this->vdivps(xsrc_prev, xsrc_prev, xa);
            this->vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev);
        }

        // current block: ya = ws^0.75, ydiffsrc = diff_dst / ws^0.75,
        // ysum = diff_dst * src / ws^1.75
        this->vmovups(ysrc, this->ptr[src_]);
        this->vmovups(yws, this->ptr[scratch_]);
        this->vmovups(ydiffdst, this->ptr[diffdst_]);
        this->vmulps(ya, yws, yws);
        this->vmulps(ya, ya, yws);
        this->vsqrtps(ya, ya);
        this->vsqrtps(ya, ya);
        this->vdivps(ydiffsrc, ydiffdst, ya);
        this->vdivps(ysum, ydiffsrc, yws);
        this->vmulps(ysum, ysum, ysrc);

        if (!is_last && !is_single) {
            // lower 4 channels of the next block, same formula as "prev"
            this->vmovups(xws_next, this->ptr[scratch_ + J.H * J.W * 32]);
            this->vmovups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]);
            this->vmovups(xdiffdst_next, this->ptr[diffdst_ + J.H * J.W * 32]);
            this->vmulps(xa, xws_next, xws_next);
            this->vmulps(xa, xa, xws_next);
            this->vsqrtps(xa, xa);
            this->vsqrtps(xa, xa);
            this->vmulps(xa, xa, xws_next);
            this->vdivps(xsrc_next, xsrc_next, xa);
            this->vmulps(xdiffdst_next, xdiffdst_next, xsrc_next);
        }

        // lay out [prev(4) | current(8) | next(4)] contiguously on the stack
        if (!is_first && !is_single)
            this->vmovups(this->ptr[t + 0], xdiffdst_prev);
        this->vmovups(this->ptr[t + 16], ysum);
        if (!is_last && !is_single)
            this->vmovups(this->ptr[t + 48], xdiffdst_next);

        // unaligned reloads shifted by -2, -1, +1, +2 channels give the
        // remaining four window terms for each of the 8 channels
        this->vmovups(ya, this->ptr[t + 16 - 8]);
        this->vmovups(yb, this->ptr[t + 16 - 4]);
        this->vaddps(ysum, ysum, ya);
        this->vmulps(ysrc, ysrc, vnalphabeta_);
        this->vaddps(ysum, ysum, yb);

        this->vmovups(yd, this->ptr[t + 16 + 4]);
        this->vmovups(ye, this->ptr[t + 16 + 8]);
        this->vaddps(ysum, ysum, yd);
        this->vaddps(ysum, ysum, ye);

        // ydiffsrc += (nalphabeta_ * src) * window_sum
        this->vfmadd231ps(ydiffsrc, ysum, ysrc);

        this->vmovups(this->ptr[diffsrc_], ydiffsrc);

        this->add(src_, 32);
        this->add(diffsrc_, 32);
        this->add(diffdst_, 32);
        this->add(scratch_, 32);

        this->dec(hw);
        this->cmp(hw, 0);
        this->jne(lrn_loop, this->T_NEAR);
    }

    this->add(t, 64);
    this->postamble();
}
1655
1656template <cpu_isa_t isa, data_type_t d_type>
1657jit_uni_lrn_bwd_kernel_t<isa, d_type>::jit_uni_lrn_bwd_kernel_t(
1658 const within_config_t &config, float A, float B, void *code_ptr,
1659 size_t code_size)
1660 : Base(config, code_ptr, code_size, jit_name())
1661 , config_(lrn_config_t::within_config)
1662 , within_config_(config)
1663 , nalphabeta_(-2.0f * A * B) {}
1664
1665template <cpu_isa_t isa, data_type_t d_type>
1666void jit_uni_lrn_bwd_kernel_t<isa, d_type>::generate(
1667 const within_config_t &config) {
1668
1669 this->preamble();
1670 if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();
1671
1672#define GET_OFF(field) offsetof(jit_args_bwd_t, field)
1673 this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]);
1674 this->mov(diffdst_, this->ptr[this->param1 + GET_OFF(diff_dst)]);
1675 this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]);
1676 this->mov(bwd_intermediate_res_,
1677 this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]);
1678 this->mov(diffsrc_, this->ptr[this->param1 + GET_OFF(diff_src)]);
1679#undef GET_OFF
1680 this->load_constant(nalphabeta_, vnalphabeta_, xnalphabeta_);
1681
1682 static const int max_reg_blocks = is_superset(isa, avx512_core) ? 3 : 1;
1683 this->within_loop(config, max_reg_blocks, prop_kind::backward);
1684
1685 this->postamble();
1686}
1687
1688template <cpu_isa_t isa, data_type_t d_type>
1689void jit_uni_lrn_bwd_kernel_t<isa, d_type>::within_body(int hoff, int Hoff,
1690 int woff, int Woff, int stride, prop_kind_t pk, const int reg_block,
1691 int pixel_offset) {
1692
1693 static const std::array<Vmm, 3> vsum {{Vmm(1), Vmm(9), Vmm(18)}};
1694 static const std::array<std::array<Vmm, 3>, 3> diff_dst {{
1695 {{Vmm(2), Vmm(3), Vmm(6)}},
1696 {{Vmm(10), Vmm(11), Vmm(23)}},
1697 {{Vmm(19), Vmm(20), Vmm(26)}},
1698 }};
1699 static const std::array<std::array<Vmm, 3>, 3> ws1 {{
1700 {{Vmm(4), Vmm(5), Vmm(15)}},
1701 {{Vmm(12), Vmm(13), Vmm(27)}},
1702 {{Vmm(21), Vmm(22), Vmm(28)}},
1703 }};
1704 static const std::array<Vmm, 3> ws0 = !this->emulate_bfloat_
1705 ? std::array<Vmm, 3> {{Vmm(29), Vmm(30), Vmm(31)}}
1706 : std::array<Vmm, 3> {{Vmm(6), Vmm(15), Vmm(23)}};
1707 static const std::array<Vmm, 3> src {{Vmm(7), Vmm(16), Vmm(24)}};
1708 static const std::array<Vmm, 3> a {{Vmm(8), Vmm(17), Vmm(25)}};
1709
1710 static const std::size_t used_tmp_regs
1711 = this->emulate_bfloat_ ? ws1[0].size() - 1 : ws1[0].size();
1712
1713 IRB_LOOP(this->uni_vxorps(vsum[irb], vsum[irb], vsum[irb]));
1714 for (int i = hoff; i <= Hoff; ++i) {
1715 for (int j = woff; j <= Woff; ++j) {
1716 const auto idx = this->tempIdx_ % used_tmp_regs;
1717 IRB_LOOP(this->load_data(diff_dst[irb][idx],
1718 this->ptr[(diffdst_ + pixel_offset + irb_off)
1719 + (i * stride + j) * this->single_pixel_offset_]));
1720 IRB_LOOP(this->load_data(ws1[irb][idx],
1721 this->ptr[(bwd_intermediate_res_ + pixel_offset + irb_off)
1722 + (i * stride + j) * this->single_pixel_offset_]));
1723
1724 if (i == 0 && j == 0) {
1725 if (utils::one_of(d_type, data_type::bf16, data_type::f16)) {
1726 IRB_LOOP(this->load_data(ws0[irb],
1727 this->ptr[(scratch_ + pixel_offset + irb_off)]));
1728 IRB_LOOP(
1729 this->vdivps(a[irb], diff_dst[irb][idx], ws0[irb]));
1730 } else {
1731 IRB_LOOP(this->vdivps(a[irb], diff_dst[irb][idx],
1732 this->ptr[(scratch_ + pixel_offset + irb_off)]));
1733 }
1734 }
1735
1736 IRB_LOOP(this->vfmadd231ps(
1737 vsum[irb], ws1[irb][idx], diff_dst[irb][idx]));
1738 ++(this->tempIdx_);
1739 }
1740 }
1741
1742 this->tempIdx_ = this->tempIdx_ % used_tmp_regs;
1743
1744 if (utils::one_of(d_type, data_type::bf16, data_type::f16)) {
1745 IRB_LOOP(this->load_data(
1746 src[irb], this->ptr[(src_ + pixel_offset + irb_off)]));
1747 IRB_LOOP(this->vmulps(src[irb], this->vnalphabeta_, src[irb]));
1748 } else {
1749 IRB_LOOP(this->vmulps(src[irb], this->vnalphabeta_,
1750 this->ptr[(src_ + pixel_offset + irb_off)]));
1751 }
1752
1753 IRB_LOOP(this->vfmadd231ps(a[irb], src[irb], vsum[irb]));
1754
1755 IRB_LOOP(this->store_data(
1756 this->ptr[diffsrc_ + pixel_offset + irb_off], a[irb]));
1757
1758 if (is_superset(isa, avx512_core))
1759 this->reg_block_idx_ = (this->reg_block_idx_ % vsum.size()) + 1;
1760}
1761
1762template <cpu_isa_t isa, data_type_t d_type>
1763void jit_uni_lrn_bwd_kernel_t<isa, d_type>::move_data_pointers(
1764 int pixel_count, prop_kind_t pk) {
1765 const int pixel_offset = this->single_pixel_offset_ * pixel_count;
1766 this->add(src_, pixel_offset);
1767 this->add(diffsrc_, pixel_offset);
1768 this->add(diffdst_, pixel_offset);
1769 this->add(scratch_, pixel_offset);
1770 this->add(bwd_intermediate_res_, pixel_offset);
1771}
1772
1773template class jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>;
1774template class jit_uni_lrn_fwd_kernel_t<avx2, dnnl::impl::data_type::f32>;
1775template class jit_uni_lrn_fwd_kernel_t<avx512_core,
1776 dnnl::impl::data_type::f32>;
1777template class jit_uni_lrn_fwd_kernel_t<avx512_core,
1778 dnnl::impl::data_type::bf16>;
1779template class jit_uni_lrn_fwd_kernel_t<avx512_core_fp16,
1780 dnnl::impl::data_type::f16>;
1781
1782template class jit_uni_lrn_kernel_t<
1783 jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>>;
1784template class jit_uni_lrn_kernel_t<
1785 jit_uni_lrn_fwd_kernel_t<avx2, dnnl::impl::data_type::f32>>;
1786template class jit_uni_lrn_kernel_t<
1787 jit_uni_lrn_fwd_kernel_t<avx512_core, dnnl::impl::data_type::f32>>;
1788template class jit_uni_lrn_kernel_t<
1789 jit_uni_lrn_fwd_kernel_t<avx512_core, dnnl::impl::data_type::bf16>>;
1790template class jit_uni_lrn_kernel_t<
1791 jit_uni_lrn_fwd_kernel_t<avx512_core_fp16, dnnl::impl::data_type::f16>>;
1792
1793template class jit_uni_lrn_bwd_kernel_t<avx512_core_fp16,
1794 dnnl::impl::data_type::f16>;
1795template class jit_uni_lrn_bwd_kernel_t<avx512_core,
1796 dnnl::impl::data_type::f32>;
1797template class jit_uni_lrn_bwd_kernel_t<avx512_core,
1798 dnnl::impl::data_type::bf16>;
1799template class jit_uni_lrn_bwd_kernel_t<avx2, dnnl::impl::data_type::f32>;
1800
1801template class jit_uni_lrn_kernel_t<
1802 jit_uni_lrn_bwd_kernel_t<avx2, dnnl::impl::data_type::f32>>;
1803template class jit_uni_lrn_kernel_t<
1804 jit_uni_lrn_bwd_kernel_t<avx512_core, dnnl::impl::data_type::f32>>;
1805template class jit_uni_lrn_kernel_t<
1806 jit_uni_lrn_bwd_kernel_t<avx512_core, dnnl::impl::data_type::bf16>>;
1807template class jit_uni_lrn_kernel_t<
1808 jit_uni_lrn_bwd_kernel_t<avx512_core_fp16, dnnl::impl::data_type::f16>>;
1809
1810} // namespace x64
1811} // namespace cpu
1812} // namespace impl
1813} // namespace dnnl
1814
1815// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
1816