1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <array> |
18 | #include <cmath> |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/nstl.hpp" |
21 | #include "common/utils.hpp" |
22 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
23 | #include "cpu/x64/lrn/jit_uni_lrn_kernel.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
30 | using namespace dnnl::impl::format_tag; |
31 | |
// Expands `statement` once per register block inside a within_body()
// implementation. The expansion site must have `reg_block` and `vsum` in
// scope; `statement` may refer to `irb` (register-block index) and `irb_off`
// (byte offset of that block).
//  - reg_block == 1: `statement` is expanded a single time with
//    irb = reg_block_idx_ % vsum.size() and irb_off = 0.
//  - otherwise: `statement` is expanded in a loop for irb = 0..reg_block-1
//    with irb_off = irb * single_pixel_offset_.
// MAYBE_UNUSED(irb_off) is presumably there to silence unused-variable
// warnings for statements that do not use the offset.
#define IRB_LOOP(statement) \
    if (1 == reg_block) { \
        const int irb_off = 0; \
        const int irb = this->reg_block_idx_ % vsum.size(); \
        statement; \
        MAYBE_UNUSED(irb_off); \
    } else { \
        for (int irb = 0; irb < reg_block; irb++) { \
            const int irb_off = irb * this->single_pixel_offset_; \
            statement; \
            MAYBE_UNUSED(irb_off); \
        } \
    }
45 | |
46 | using namespace Xbyak; |
47 | |
// Common constructor: forwards the kernel name and code buffer to
// jit_generator and decides whether bf16 down-conversion has to be emulated.
// Emulation is selected only when isa == avx512_core, the data type is bf16
// and mayiuse(avx512_core_bf16) is false; in every other case bf16_emu_ is
// left null.
template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
        cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_kernel_t<Derived<isa, d_type>>::jit_uni_lrn_kernel_t(
        void *code_ptr, size_t code_size, const char *name)
    : jit_generator(name, code_ptr, code_size, true, isa)
    , emulate_bfloat_(isa == avx512_core
              && d_type == dnnl::impl::data_type::bf16
              && !mayiuse(avx512_core_bf16))
    , bf16_emu_(
              emulate_bfloat_ ? utils::make_unique<bf16_emulation_t>(this,
                      bf16_emu_reserv_1_, bf16_emu_reserv_2_,
                      bf16_emu_reserv_3_, bf16_emu_scratch_, bf16_emu_reserv_4_)
                              : nullptr) {}
61 | |
62 | template <template <cpu_isa_t isa, data_type_t d_type> class Derived, |
63 | cpu_isa_t isa, data_type_t d_type> |
64 | jit_uni_lrn_kernel_t<Derived<isa, d_type>>::jit_uni_lrn_kernel_t( |
65 | const within_config_t &config, void *code_ptr, size_t code_size, |
66 | const char *name) |
67 | : jit_uni_lrn_kernel_t(code_ptr, code_size, name) { |
68 | if (config.dat_tag == nhwc) |
69 | single_pixel_offset_ |
70 | = config.C * sizeof(typename prec_traits<d_type>::type); |
71 | } |
72 | |
// Destructor is defaulted here, out of line, in the implementation file.
template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
        cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_kernel_t<Derived<isa, d_type>>::~jit_uni_lrn_kernel_t() = default;
76 | |
77 | template <template <cpu_isa_t isa, data_type_t d_type> class Derived, |
78 | cpu_isa_t isa, data_type_t d_type> |
79 | void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::within_loop( |
80 | const within_config_t &config, int max_reg_blocks, prop_kind_t pk) { |
81 | const auto derived_ptr = static_cast<Derived<isa, d_type> *>(this); |
82 | |
83 | const int lower_bound = (config.size - 1) / 2, |
84 | upper_bound = config.size - lower_bound - 1; |
85 | |
86 | int pixel_count = 0; |
87 | |
88 | for (int i = 0; i < lower_bound; ++i) { |
89 | pixel_count = 0; |
90 | for (int j = 0; j < lower_bound; ++j) |
91 | derived_ptr->within_body(-i, upper_bound, -j, upper_bound, config.W, |
92 | pk, 1, pixel_count++ * this->single_pixel_offset_); |
93 | derived_ptr->move_data_pointers(pixel_count, pk); |
94 | |
95 | within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks, -i, |
96 | upper_bound, -lower_bound, upper_bound, config.W, pk); |
97 | |
98 | pixel_count = 0; |
99 | for (int j = config.W - upper_bound; j < config.W; ++j) |
100 | derived_ptr->within_body(-i, upper_bound, -lower_bound, |
101 | config.W - 1 - j, config.W, pk, 1, |
102 | pixel_count++ * this->single_pixel_offset_); |
103 | derived_ptr->move_data_pointers(pixel_count, pk); |
104 | } |
105 | |
106 | this->mov(h_, config.H - config.size + 1); |
107 | Label lrn_loop_h; |
108 | this->L(lrn_loop_h); |
109 | pixel_count = 0; |
110 | for (int j = 0; j < lower_bound; ++j) |
111 | derived_ptr->within_body(-lower_bound, upper_bound, -j, upper_bound, |
112 | config.W, pk, 1, pixel_count++ * this->single_pixel_offset_); |
113 | derived_ptr->move_data_pointers(pixel_count, pk); |
114 | |
115 | within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks, |
116 | -lower_bound, upper_bound, -lower_bound, upper_bound, config.W, pk); |
117 | |
118 | pixel_count = 0; |
119 | for (int j = config.W - upper_bound; j < config.W; ++j) |
120 | derived_ptr->within_body(-lower_bound, upper_bound, -lower_bound, |
121 | config.W - 1 - j, config.W, pk, 1, |
122 | pixel_count++ * this->single_pixel_offset_); |
123 | derived_ptr->move_data_pointers(pixel_count, pk); |
124 | |
125 | this->dec(h_); |
126 | this->cmp(h_, 0); |
127 | this->jne(lrn_loop_h, this->T_NEAR); |
128 | |
129 | for (int i = config.H - upper_bound; i < config.H; ++i) { |
130 | pixel_count = 0; |
131 | for (int j = 0; j < lower_bound; ++j) |
132 | derived_ptr->within_body(-lower_bound, config.H - 1 - i, -j, |
133 | upper_bound, config.W, pk, 1, |
134 | pixel_count++ * this->single_pixel_offset_); |
135 | derived_ptr->move_data_pointers(pixel_count, pk); |
136 | |
137 | within_body_reg_blocked(config.W - config.size + 1, max_reg_blocks, |
138 | -lower_bound, config.H - 1 - i, -lower_bound, upper_bound, |
139 | config.W, pk); |
140 | |
141 | pixel_count = 0; |
142 | for (int j = config.W - upper_bound; j < config.W; ++j) |
143 | derived_ptr->within_body(-lower_bound, config.H - 1 - i, |
144 | -lower_bound, config.W - 1 - j, config.W, pk, 1, |
145 | pixel_count++ * this->single_pixel_offset_); |
146 | derived_ptr->move_data_pointers(pixel_count, pk); |
147 | } |
148 | } |
149 | |
150 | template <template <cpu_isa_t isa, data_type_t d_type> class Derived, |
151 | cpu_isa_t isa, data_type_t d_type> |
152 | void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::within_body_reg_blocked( |
153 | int loop_count, int max_reg_blocks, int hoff, int Hoff, int woff, |
154 | int Woff, int stride, prop_kind_t pk) { |
155 | |
156 | const auto derived_ptr = static_cast<Derived<isa, d_type> *>(this); |
157 | Label reg_block_compute_loop; |
158 | |
159 | const auto res = std::div(loop_count, max_reg_blocks); |
160 | if (res.quot) { |
161 | this->mov(this->w_, res.quot); |
162 | this->L(reg_block_compute_loop); |
163 | derived_ptr->within_body( |
164 | hoff, Hoff, woff, Woff, stride, pk, max_reg_blocks, 0); |
165 | derived_ptr->move_data_pointers(max_reg_blocks, pk); |
166 | this->dec(this->w_); |
167 | this->cmp(this->w_, 0); |
168 | this->jne(reg_block_compute_loop, this->T_NEAR); |
169 | } |
170 | if (res.rem) { |
171 | derived_ptr->within_body( |
172 | hoff, Hoff, woff, Woff, stride, pk, res.rem, 0); |
173 | derived_ptr->move_data_pointers(res.rem, pk); |
174 | } |
175 | } |
176 | |
177 | template <template <cpu_isa_t isa, data_type_t d_type> class Derived, |
178 | cpu_isa_t isa, data_type_t d_type> |
179 | void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::load_data( |
180 | const Vmm ®, const Xbyak::Address &p) { |
181 | this->uni_vmovups(reg, p); |
182 | } |
183 | |
// Emits a bf16 -> f32 load through `generator`: the 16-bit values at `p` are
// zero-extended to 32 bits (vpmovzxwd) and then shifted left by 16 bits
// (vpslld), which places each bf16 payload in the upper half of a 32-bit
// lane.
//
// generator - pointer-like object exposing vpmovzxwd() and vpslld().
// reg       - destination register (also the source of the shift).
// p         - memory operand to load from.
// (The parameter name `reg` had been mangled into a single non-ASCII
// character in the signature, leaving the body's `reg` undeclared.)
template <typename Gen, typename Reg, typename Addr>
void load_bf16_data(Gen generator, const Reg &reg, const Addr &p) {
    generator->vpmovzxwd(reg, p);
    generator->vpslld(reg, reg, 0x10);
}
189 | |
190 | template <> |
191 | void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_core, |
192 | dnnl::impl::data_type::bf16>>::load_data(const Vmm ®, |
193 | const Xbyak::Address &p) { |
194 | load_bf16_data(this, reg, p); |
195 | } |
196 | |
197 | template <> |
198 | void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_core, |
199 | dnnl::impl::data_type::bf16>>::load_data(const Vmm ®, |
200 | const Xbyak::Address &p) { |
201 | load_bf16_data(this, reg, p); |
202 | } |
203 | |
204 | template <> |
205 | void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_core_fp16, |
206 | dnnl::impl::data_type::f16>>::load_data(const Vmm ®, |
207 | const Xbyak::Address &p) { |
208 | vcvtph2ps(reg, p); |
209 | } |
210 | |
211 | template <> |
212 | void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_core_fp16, |
213 | dnnl::impl::data_type::f16>>::load_data(const Vmm ®, |
214 | const Xbyak::Address &p) { |
215 | vcvtph2ps(reg, p); |
216 | } |
217 | |
218 | template <template <cpu_isa_t isa, data_type_t d_type> class Derived, |
219 | cpu_isa_t isa, data_type_t d_type> |
220 | void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::store_data( |
221 | const Xbyak::Address &addr, const Vmm ®) { |
222 | this->uni_vmovups(addr, reg); |
223 | } |
224 | |
225 | template <typename Gen, typename Bf16Emu> |
226 | void store_bf16_data( |
227 | Gen generator, Bf16Emu emu, const Xbyak::Address &addr, const Zmm &zr) { |
228 | const Ymm yr = Ymm(zr.getIdx()); |
229 | if (mayiuse(avx512_core_bf16)) |
230 | generator->vcvtneps2bf16(yr, zr); |
231 | else |
232 | emu->vcvtneps2bf16(yr, zr); |
233 | generator->vmovdqu16(addr, yr); |
234 | } |
235 | |
236 | template <> |
237 | void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_core, |
238 | dnnl::impl::data_type::bf16>>::store_data(const Xbyak::Address &addr, |
239 | const Zmm &zr) { |
240 | store_bf16_data(this, bf16_emu_.get(), addr, zr); |
241 | } |
242 | |
243 | template <> |
244 | void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_core, |
245 | dnnl::impl::data_type::bf16>>::store_data(const Xbyak::Address &addr, |
246 | const Zmm &zr) { |
247 | store_bf16_data(this, bf16_emu_.get(), addr, zr); |
248 | } |
249 | |
250 | template <> |
251 | void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<avx512_core_fp16, |
252 | dnnl::impl::data_type::f16>>::store_data(const Xbyak::Address &addr, |
253 | const Zmm &zr) { |
254 | vcvtps2ph(addr, zr, _op_mxcsr); |
255 | } |
256 | |
257 | template <> |
258 | void jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<avx512_core_fp16, |
259 | dnnl::impl::data_type::f16>>::store_data(const Xbyak::Address &addr, |
260 | const Zmm &zr) { |
261 | vcvtps2ph(addr, zr, _op_mxcsr); |
262 | } |
263 | |
264 | template <template <cpu_isa_t isa, data_type_t d_type> class Derived, |
265 | cpu_isa_t isa, data_type_t d_type> |
266 | void jit_uni_lrn_kernel_t<Derived<isa, d_type>>::load_constant( |
267 | float constant, const Vmm &v_constant, const Xbyak::Xmm &x_constant) { |
268 | this->mov(this->imm_addr64_, float2int(constant)); |
269 | this->uni_vmovq(x_constant, this->imm_addr64_); |
270 | this->vbroadcastss(v_constant, x_constant); |
271 | } |
272 | |
273 | template <> |
274 | void jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<sse41, |
275 | dnnl::impl::data_type::f32>>::load_constant(float constant, |
276 | const Vmm &v_constant, const Xbyak::Xmm &x_constant) { |
277 | this->mov(this->imm_addr64_, float2int(constant)); |
278 | this->uni_vmovq(x_constant, this->imm_addr64_); |
279 | this->shufps(x_constant, x_constant, 0); |
280 | } |
281 | |
282 | ////////////////////////////////////////////////////////////////////////////// |
283 | // forward kernel |
284 | template <cpu_isa_t isa, data_type_t d_type> |
285 | void jit_uni_lrn_fwd_kernel_t<isa, d_type>::within_body(int hoff, int Hoff, |
286 | int woff, int Woff, int stride, prop_kind_t pk, const int reg_block, |
287 | int pixel_offset) { |
288 | |
289 | static const std::array<Vmm, 3> vsum {{Vmm(2), Vmm(11), Vmm(20)}}; |
290 | static const std::array<Vmm, 3> vsum2 {{Vmm(3), Vmm(12), Vmm(21)}}; |
291 | static const std::array<Vmm, 3> vdst {{Vmm(4), Vmm(13), Vmm(22)}}; |
292 | static const std::array<std::array<Vmm, 6u>, 3u> vtmp { |
293 | {{{Vmm(5), Vmm(6), Vmm(7), Vmm(8), Vmm(9), Vmm(14)}}, |
294 | {{Vmm(18), Vmm(15), Vmm(16), Vmm(17), Vmm(29), Vmm(30)}}, |
295 | {{Vmm(23), Vmm(24), Vmm(25), Vmm(26), Vmm(28), Vmm(31)}}}}; |
296 | static const std::array<Vmm, 3> vscratch = {{Vmm(10), Vmm(19), Vmm(27)}}; |
297 | static const std::size_t used_tmp_regs |
298 | = this->emulate_bfloat_ ? vtmp[0].size() - 2 : vtmp[0].size(); |
299 | |
300 | IRB_LOOP(this->uni_vxorps(vsum[irb], vsum[irb], vsum[irb])); |
301 | for (int i = hoff; i <= Hoff; ++i) { |
302 | for (int j = woff; j <= Woff; ++j) { |
303 | if (i == 0 && j == 0) { |
304 | IRB_LOOP(this->load_data( |
305 | vdst[irb], this->ptr[src_ + pixel_offset + irb_off])); |
306 | IRB_LOOP(this->vfmadd231ps(vsum[irb], vdst[irb], vdst[irb])); |
307 | } else { |
308 | const auto idx = this->tempIdx_ % used_tmp_regs; |
309 | IRB_LOOP(this->load_data(vtmp[irb][idx], |
310 | this->ptr[(src_ + pixel_offset + irb_off) |
311 | + (i * stride + j) |
312 | * this->single_pixel_offset_])); |
313 | IRB_LOOP(this->vfmadd231ps( |
314 | vsum[irb], vtmp[irb][idx], vtmp[irb][idx])); |
315 | ++(this->tempIdx_); |
316 | } |
317 | } |
318 | } |
319 | |
320 | this->tempIdx_ = this->tempIdx_ % used_tmp_regs; |
321 | |
322 | IRB_LOOP(this->vfmadd132ps( |
323 | vsum[irb], vk_, valpha_)); // ysum <- ysum*valpha_+yk_ |
324 | IRB_LOOP(this->vmovaps(vscratch[irb], vsum[irb])); |
325 | |
326 | IRB_LOOP(this->vmulps(vsum2[irb], vsum[irb], vsum[irb])); |
327 | IRB_LOOP(this->vmulps( |
328 | vsum[irb], vsum[irb], vsum2[irb])); // ysum = (ysum*valpha_+yk_)^3; |
329 | IRB_LOOP(this->vsqrtps(vsum[irb], vsum[irb])); |
330 | IRB_LOOP(this->vsqrtps( |
331 | vsum[irb], vsum[irb])); // ysum = (ysum*valpha_+yk_)^0.75 |
332 | IRB_LOOP(this->vdivps( |
333 | vdst[irb], vdst[irb], vsum[irb])); // ydst <- ydst / ysum |
334 | |
335 | if (pk_ != prop_kind::forward_inference) { |
336 | IRB_LOOP(this->store_data( |
337 | this->ptr[scratch_ + pixel_offset + irb_off], vsum[irb])); |
338 | IRB_LOOP(this->vdivps(vscratch[irb], vdst[irb], vscratch[irb])); |
339 | IRB_LOOP(this->store_data( |
340 | this->ptr[bwd_intermediate_res_ + pixel_offset + irb_off], |
341 | vscratch[irb])); |
342 | } |
343 | |
344 | IRB_LOOP(this->store_data( |
345 | this->ptr[dst_ + pixel_offset + irb_off], vdst[irb])); |
346 | |
347 | if (is_superset(isa, avx512_core)) |
348 | this->reg_block_idx_ = (this->reg_block_idx_ % vsum.size()) + 1; |
349 | } |
350 | |
351 | template <> |
352 | void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::within_body( |
353 | int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk, |
354 | int reg_block, int pixel_offset) { |
355 | |
356 | const Xbyak::Xmm &xtmp_lo = this->xmm2; |
357 | const Xbyak::Xmm &xtmp_hi = this->xmm3; |
358 | const Xbyak::Xmm &xsum_lo = this->xmm4; |
359 | const Xbyak::Xmm &xsum_hi = this->xmm5; |
360 | const Xbyak::Xmm &xdst_lo = this->xmm6; |
361 | const Xbyak::Xmm &xdst_hi = this->xmm7; |
362 | const Xbyak::Xmm &xsum2_lo = this->xmm8; |
363 | const Xbyak::Xmm &xsum2_hi = this->xmm9; |
364 | |
365 | xorps(xsum_lo, xsum_lo); |
366 | xorps(xsum_hi, xsum_hi); |
367 | for (int i = hoff; i <= Hoff; ++i) { |
368 | for (int j = woff; j <= Woff; ++j) { |
369 | if (i == 0 && j == 0) { |
370 | movups(xdst_lo, ptr[src_ + pixel_offset]); |
371 | movups(xdst_hi, ptr[src_ + pixel_offset + 4 * sizeof(float)]); |
372 | mulps(xdst_lo, xdst_lo); |
373 | mulps(xdst_hi, xdst_hi); |
374 | addps(xsum_lo, xdst_lo); |
375 | addps(xsum_hi, xdst_hi); |
376 | } else { |
377 | movups(xtmp_lo, |
378 | ptr[src_ + pixel_offset |
379 | + (i * stride + j) * single_pixel_offset_]); |
380 | movups(xtmp_hi, |
381 | ptr[src_ + pixel_offset |
382 | + (i * stride + j) * single_pixel_offset_ |
383 | + 4 * sizeof(float)]); |
384 | this->mulps(xtmp_lo, xtmp_lo); |
385 | this->mulps(xtmp_hi, xtmp_hi); |
386 | this->addps(xsum_lo, xtmp_lo); |
387 | this->addps(xsum_hi, xtmp_hi); |
388 | } |
389 | } |
390 | } |
391 | this->mulps(xsum_lo, xalpha_); |
392 | this->mulps(xsum_hi, xalpha_); |
393 | this->addps(xsum_lo, xk_); |
394 | this->addps(xsum_hi, xk_); // xsum <- xsum*xalpha_+xk_ |
395 | this->movaps(xtmp_lo, xsum_lo); |
396 | this->movaps(xtmp_hi, xsum_hi); |
397 | if (pk_ != prop_kind::forward_inference) { |
398 | this->movups(this->ptr[scratch_ + pixel_offset], xtmp_lo); |
399 | this->movups(this->ptr[scratch_ + pixel_offset + 4 * sizeof(float)], |
400 | xtmp_hi); |
401 | } |
402 | this->movaps(xsum2_lo, xsum_lo); |
403 | this->movaps(xsum2_hi, xsum_hi); |
404 | this->mulps(xsum2_lo, xsum_lo); |
405 | this->mulps(xsum2_hi, xsum_hi); |
406 | this->mulps(xsum_lo, xsum2_lo); |
407 | this->mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha_+xk_)^3; |
408 | |
409 | this->sqrtps(xsum_lo, xsum_lo); |
410 | this->sqrtps(xsum_hi, xsum_hi); |
411 | this->sqrtps(xsum_lo, xsum_lo); |
412 | this->sqrtps(xsum_hi, xsum_hi); // xsum = (xsum*xalpha_+xk_)^0.75 |
413 | |
414 | this->movups(xdst_lo, this->ptr[src_ + pixel_offset]); |
415 | this->movups(xdst_hi, this->ptr[src_ + pixel_offset + 4 * sizeof(float)]); |
416 | this->divps(xdst_lo, xsum_lo); |
417 | this->divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum |
418 | |
419 | this->movups(this->ptr[dst_ + pixel_offset], xdst_lo); |
420 | this->movups(this->ptr[dst_ + pixel_offset + 4 * sizeof(float)], xdst_hi); |
421 | } |
422 | |
423 | template <cpu_isa_t isa, data_type_t d_type> |
424 | void jit_uni_lrn_fwd_kernel_t<isa, d_type>::move_data_pointers( |
425 | int pixel_count, prop_kind_t pk) { |
426 | |
427 | const int pixel_offset = this->single_pixel_offset_ * pixel_count; |
428 | this->add(src_, pixel_offset); |
429 | this->add(dst_, pixel_offset); |
430 | if (pk_ != prop_kind::forward_inference) { |
431 | this->add(scratch_, pixel_offset); |
432 | this->add(bwd_intermediate_res_, pixel_offset); |
433 | } |
434 | } |
435 | |
// Constructs a forward kernel for the "within" configuration.
// Stores the configuration, A as alpha_, K as k_ and the propagation kind;
// the base class receives the config together with the code buffer.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
        const within_config_t &config, float A, float K, prop_kind_t pk,
        void *code_ptr, size_t code_size)
    : Base(config, code_ptr, code_size, jit_name())
    , config_(lrn_config_t::within_config)
    , within_config_(config)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
446 | |
447 | template <cpu_isa_t isa, data_type_t d_type> |
448 | void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate( |
449 | const within_config_t &config) { |
450 | this->preamble(); |
451 | if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16(); |
452 | |
453 | #define GET_OFF(field) offsetof(jit_args_fwd_t, field) |
454 | this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]); |
455 | this->mov(dst_, this->ptr[this->param1 + GET_OFF(dst)]); |
456 | if (pk_ != prop_kind::forward_inference) { |
457 | this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]); |
458 | this->mov(bwd_intermediate_res_, |
459 | this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]); |
460 | } |
461 | #undef GET_OFF |
462 | |
463 | this->load_constant(alpha_, valpha_, xalpha_); |
464 | this->load_constant(k_, vk_, xk_); |
465 | |
466 | static const int max_reg_blocks = is_superset(isa, avx512_core) ? 3 : 1; |
467 | this->within_loop(config, max_reg_blocks, pk_); |
468 | |
469 | this->postamble(); |
470 | } |
471 | |
// Constructs a forward kernel for the nchw8c across-channel configuration.
// Stores the configuration J, A as alpha_, K as k_ and the propagation kind.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
        const struct nchw8c_across_t &J, float A, float K, prop_kind_t pk,
        void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size, jit_name())
    , config_(lrn_config_t::nchw8c_across)
    , nchw8c_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
482 | |
// Emits the forward kernel for the nchw8c across-channel configuration.
//
// Each loop iteration handles one 32-byte element (8 floats) and the loop
// runs J.H * J.W times. A 64-byte area reserved on the stack at `t` (rsp) is
// laid out as [prev 16B | current 32B | next 16B]; the neighbour vectors are
// read back from it at -8, -4, +4 and +8 bytes (i.e. -2, -1, +1, +2 floats)
// around the current element.
//
// J.version selects the border handling: with -1 the "prev" part is zeroed
// once and never loaded, with +1 the same holds for the "next" part
// (presumably the first / last channel block - confirm against the caller).
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nchw8c_across_t &J) {
    const Xbyak::Reg64 &t = this->rsp; // base of the 64-byte stack buffer
    const Xbyak::Reg64 &hw = this->r9; // remaining iterations
    const Xbyak::Xmm &xsrc_prev = this->xmm2;
    const Xbyak::Ymm &ysrc = this->ymm3;
    const Xbyak::Ymm &yc = this->ymm3; // same register as ysrc
    const Xbyak::Xmm &xsrc_next = this->xmm4;
    const Xbyak::Ymm &ya = this->ymm5;
    const Xbyak::Ymm &yb = this->ymm6;
    const Xbyak::Ymm &yd = this->ymm7;
    const Xbyak::Ymm &ye = this->ymm8;
    const Xbyak::Ymm &ysum = this->ymm9;
    const Xbyak::Ymm &ysum2 = this->ymm10;
    const Xbyak::Ymm &ydst = this->ymm11;
    const Xbyak::Ymm &ybase = this->ymm12;

    this->preamble();
    if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();

    // Kernel arguments are read at fixed byte offsets 0 / 8 / 16 of param1.
    this->mov(src_, this->ptr[this->param1 + 0]);
    this->mov(dst_, this->ptr[this->param1 + 8]);
    if (pk_ != prop_kind::forward_inference)
        this->mov(scratch_, this->ptr[this->param1 + 16]);
    this->sub(t, 64); // reserve the stack buffer (released before postamble)
    // Broadcast alpha and k to all lanes.
    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->vmovq(xalpha_, this->imm_addr64_);
    this->vbroadcastss(valpha_, xalpha_);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->vmovq(xk_, this->imm_addr64_);
    this->vbroadcastss(yk_, xk_);

    // Border versions: the missing neighbour part is zero-filled once.
    if (J.version == -1) {
        this->vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
        this->vmovups(this->ptr[t + 0], xsrc_prev);
    }
    if (J.version == +1) {
        this->vxorps(xsrc_next, xsrc_next, xsrc_next);
        this->vmovups(this->ptr[t + 48], xsrc_next);
    }

    this->mov(hw, J.H * J.W);

    Label lrn_loop;
    this->L(lrn_loop);

    // Load the upper 16 bytes of the element J.H * J.W elements before src_,
    // the current element, and the lower 16 bytes of the element
    // J.H * J.W elements after src_.
    if (J.version != -1)
        this->vmovups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]);
    this->vmovups(ysrc, this->ptr[src_]);
    if (J.version != +1)
        this->vmovups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]);

    // Spill them contiguously to the stack buffer.
    if (J.version != -1) this->vmovups(this->ptr[t + 0], xsrc_prev);
    this->vmovups(this->ptr[t + 16], ysrc);
    if (J.version != +1) this->vmovups(this->ptr[t + 48], xsrc_next);

    // Re-read the buffer shifted by -2, -1, +1 and +2 floats.
    this->vmovups(ya, this->ptr[t + 16 - 8]);
    this->vmovups(yb, this->ptr[t + 16 - 4]);
    this->vmovups(yd, this->ptr[t + 16 + 4]);
    this->vmovups(ye, this->ptr[t + 16 + 8]);
    this->vmulps(ysum, yc, yc);
    this->vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya
    this->vfmadd231ps(ysum, yb, yb);
    this->vfmadd231ps(ysum, yd, yd);
    this->vfmadd231ps(ysum, ye, ye);
    this->vfmadd132ps(ysum, yk_, valpha_); // ysum <- ysum*valpha_+yk_

    this->vmovaps(ybase, ysum);
    // The base value is kept in scratch_ unless running forward_inference.
    if (pk_ != prop_kind::forward_inference)
        this->vmovups(this->ptr[scratch_], ybase);
    this->vmulps(ysum2, ysum, ysum);
    this->vmulps(ysum, ysum, ysum2); // ysum = ybase^3;
    this->vsqrtps(ysum, ysum);
    this->vsqrtps(ysum, ysum); // ysum = ybase^0.75
    this->vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum
    this->vmovups(this->ptr[dst_], ydst);

    // Advance every pointer by one 32-byte element and loop.
    this->add(src_, 32);
    this->add(dst_, 32);
    if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32);
    this->dec(hw);
    this->cmp(hw, 0);
    this->jne(lrn_loop, this->T_NEAR);

    this->add(t, 64); // release the stack buffer
    this->postamble();
}
571 | |
// sse41 / f32 specialization of the nchw8c across-channel constructor.
// Stores the configuration J, A as alpha_, K as k_ and the propagation kind.
template <>
jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
        jit_uni_lrn_fwd_kernel_t(const struct nchw8c_across_t &J, float A,
                float K, prop_kind_t pk, void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size, jit_name())
    , config_(lrn_config_t::nchw8c_across)
    , nchw8c_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
582 | |
// sse41 / f32 forward kernel for the nchw8c across-channel configuration.
//
// Same scheme as the generic version, but every 8-float element is handled
// as a low and a high Xmm half (the high half sits 4 * sizeof(float) bytes
// after the low one). The loop runs J.H * J.W times. A 64-byte stack area at
// `t` (rsp) holds [prev 16B | current 32B | next 16B]; neighbours are read
// back from it at -8, -4, +4 and +8 bytes around the current element.
//
// J.version == -1 / +1 zero-fills the "prev" / "next" part once instead of
// loading it (presumably the first / last channel block - confirm against
// the caller).
template <>
void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate(
        const nchw8c_across_t &J) {

    const Xbyak::Reg64 &t = this->rsp; // base of the 64-byte stack buffer
    const Xbyak::Reg64 &hw = this->r9; // remaining iterations
    const Xbyak::Xmm &xsrc_lo = this->xmm2;
    const Xbyak::Xmm &xsrc_hi = this->xmm3;
    const Xbyak::Xmm &xc_lo = this->xmm4;
    const Xbyak::Xmm &xc_hi = this->xmm5;
    const Xbyak::Xmm &xsum_lo = xc_lo; // the sum is accumulated in place in xc
    const Xbyak::Xmm &xsum_hi = xc_hi;
    const Xbyak::Xmm &xsrc_prev = this->xmm6;
    const Xbyak::Xmm &xsrc_next = this->xmm7;
    const Xbyak::Xmm &xa_lo = this->xmm8;
    const Xbyak::Xmm &xa_hi = this->xmm9;
    const Xbyak::Xmm &xb_lo = this->xmm10;
    const Xbyak::Xmm &xb_hi = this->xmm11;
    const Xbyak::Xmm &xd_lo = this->xmm12;
    const Xbyak::Xmm &xd_hi = this->xmm13;
    const Xbyak::Xmm &xe_lo = this->xmm14;
    const Xbyak::Xmm &xe_hi = this->xmm15;
    // xbase shares xmm14/xmm15 with xe; it is written only after xe has been
    // added into the sum.
    const Xbyak::Xmm &xbase_lo = this->xmm14;
    const Xbyak::Xmm &xbase_hi = this->xmm15;

    this->preamble();
    if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();

    // Kernel arguments are read at fixed byte offsets 0 / 8 / 16 of param1.
    this->mov(src_, this->ptr[this->param1 + 0]);
    this->mov(dst_, this->ptr[this->param1 + 8]);
    if (pk_ != prop_kind::forward_inference)
        this->mov(scratch_, this->ptr[this->param1 + 16]);
    this->sub(t, 64); // reserve the stack buffer (released before postamble)
    // Replicate alpha and k to all four lanes.
    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->movq(xalpha_, this->imm_addr64_);
    this->shufps(xalpha_, xalpha_, 0);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->movq(xk_, this->imm_addr64_);
    this->shufps(xk_, xk_, 0);

    // Border versions: the missing neighbour part is zero-filled once.
    if (J.version == -1) {
        this->xorps(xsrc_prev, xsrc_prev);
        this->movups(this->ptr[t + 0], xsrc_prev);
    }
    if (J.version == +1) {
        this->xorps(xsrc_next, xsrc_next);
        this->movups(this->ptr[t + 48], xsrc_next);
    }

    this->mov(hw, J.H * J.W);
    Label lrn_loop;
    L(lrn_loop);

    // Load the upper 16 bytes of the element J.H * J.W elements before src_,
    // both halves of the current element, and the lower 16 bytes of the
    // element J.H * J.W elements after src_.
    if (J.version != -1)
        this->movups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]);
    this->movups(xsrc_lo, this->ptr[src_]);
    this->movups(xsrc_hi, this->ptr[src_ + 4 * sizeof(float)]);
    if (J.version != +1)
        this->movups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]);

    // Spill them contiguously to the stack buffer.
    if (J.version != -1) this->movups(this->ptr[t + 0], xsrc_prev);
    this->movups(this->ptr[t + 16], xsrc_lo);
    this->movups(this->ptr[t + 16 + 4 * sizeof(float)], xsrc_hi);
    if (J.version != +1) this->movups(this->ptr[t + 48], xsrc_next);

    // Re-read the buffer shifted by -2, -1, +1 and +2 floats.
    this->movups(xa_lo, this->ptr[t + 16 - 8]);
    this->movups(xa_hi, this->ptr[t + 16 - 8 + 4 * sizeof(float)]);
    this->movups(xb_lo, this->ptr[t + 16 - 4]);
    this->movups(xb_hi, this->ptr[t + 16 - 4 + 4 * sizeof(float)]);
    this->movups(xd_lo, this->ptr[t + 16 + 4]);
    this->movups(xd_hi, this->ptr[t + 16 + 4 + 4 * sizeof(float)]);
    this->movups(xe_lo, this->ptr[t + 16 + 8]);
    this->movups(xe_hi, this->ptr[t + 16 + 8 + 4 * sizeof(float)]);
    this->movaps(xc_lo, xsrc_lo);
    this->movaps(xc_hi, xsrc_hi);
    // xsum aliases xc, so these two multiplies square the centre values.
    this->mulps(xsum_lo, xc_lo);
    this->mulps(xsum_hi, xc_hi);
    this->mulps(xa_lo, xa_lo);
    this->mulps(xa_hi, xa_hi);
    this->addps(xsum_lo, xa_lo);
    this->addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa
    this->mulps(xb_lo, xb_lo);
    this->mulps(xb_hi, xb_hi);
    this->addps(xsum_lo, xb_lo);
    this->addps(xsum_hi, xb_hi);
    this->mulps(xd_lo, xd_lo);
    this->mulps(xd_hi, xd_hi);
    this->addps(xsum_lo, xd_lo);
    this->addps(xsum_hi, xd_hi);
    this->mulps(xe_lo, xe_lo);
    this->mulps(xe_hi, xe_hi);
    this->addps(xsum_lo, xe_lo);
    this->addps(xsum_hi, xe_hi);

    this->mulps(xsum_lo, xalpha_);
    this->mulps(xsum_hi, xalpha_);
    this->addps(xsum_lo, xk_);
    this->addps(xsum_hi, xk_); // xsum <- xsum*xalpha_+xk_

    this->movaps(xbase_lo, xsum_lo);
    this->movaps(xbase_hi, xsum_hi);
    // The base value is kept in scratch_ unless running forward_inference.
    if (pk_ != prop_kind::forward_inference) {
        this->movups(this->ptr[scratch_], xbase_lo);
        this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi);
    }
    this->mulps(xsum_lo, xsum_lo);
    this->mulps(xsum_hi, xsum_hi);
    this->mulps(xsum_lo, xbase_lo);
    this->mulps(xsum_hi, xbase_hi); // xsum = xbase^3;
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi);
    this->sqrtps(xsum_lo, xsum_lo);
    this->sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75
    // The result is formed in place in the xsrc registers.
    this->divps(xsrc_lo, xsum_lo);
    this->divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum
    this->movups(this->ptr[dst_], xsrc_lo);
    this->movups(this->ptr[dst_ + 4 * sizeof(float)], xsrc_hi);

    // Advance every pointer by one 32-byte element and loop.
    this->add(src_, 32);
    this->add(dst_, 32);
    if (pk_ != prop_kind::forward_inference) add(scratch_, 32);
    this->dec(hw);
    this->cmp(hw, 0);
    this->jne(lrn_loop, this->T_NEAR);

    this->add(t, 64); // release the stack buffer
    this->postamble();
}
712 | |
// Constructs a forward kernel for the nhwc across-channel configuration.
// Stores the configuration J, A as alpha_, K as k_ and the propagation kind.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
        const struct nhwc_across_t &J, float A, float K, prop_kind_t pk,
        void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size, jit_name())
    , config_(lrn_config_t::nhwc_across)
    , nhwc_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
723 | |
724 | template <cpu_isa_t isa, data_type_t d_type> |
725 | void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nhwc_across_t &J) { |
726 | static const uint32_t mask[] = {0, 0, 0x80000000, 0x80000000, 0x80000000, |
727 | 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0, 0}; |
728 | |
729 | const Xbyak::Reg64 &c = this->r9; |
730 | const Xbyak::Ymm &ya = this->ymm2; |
731 | const Xbyak::Ymm &yb = this->ymm3; |
732 | const Xbyak::Ymm &yc = this->ymm4; |
733 | const Xbyak::Ymm &yd = this->ymm5; |
734 | const Xbyak::Ymm &ye = this->ymm6; |
735 | const Xbyak::Ymm &ysum = this->ymm7; |
736 | const Xbyak::Ymm &ydst = this->ymm8; |
737 | const Xbyak::Ymm &ybase = this->ymm9; |
738 | const Xbyak::Ymm &ymask = this->ymm10; |
739 | |
740 | this->preamble(); |
741 | if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16(); |
742 | |
743 | this->mov(src_, this->ptr[this->param1 + 0]); |
744 | this->mov(dst_, this->ptr[this->param1 + 8]); |
745 | if (pk_ != prop_kind::forward_inference) |
746 | this->mov(scratch_, this->ptr[this->param1 + 16]); |
747 | this->mov(this->imm_addr64_, float2int(this->alpha_)); |
748 | this->vmovq(xalpha_, this->imm_addr64_); |
749 | this->vbroadcastss(valpha_, xalpha_); |
750 | |
751 | this->mov(this->imm_addr64_, float2int(this->k_)); |
752 | this->vmovq(xk_, this->imm_addr64_); |
753 | this->vbroadcastss(yk_, xk_); |
754 | |
755 | this->vxorps(ysum, ysum, ysum); |
756 | |
757 | this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[0])); |
758 | this->vmovups(ymask, this->ptr[this->imm_addr64_]); |
759 | this->vmaskmovps(ya, ymask, this->ptr[src_ - 8]); |
760 | this->vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 |
761 | |
762 | this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[1])); |
763 | this->vmovups(ymask, this->ptr[this->imm_addr64_]); |
764 | this->vmaskmovps(yb, ymask, this->ptr[src_ - 4]); |
765 | this->vfmadd231ps(ysum, yb, yb); |
766 | |
767 | this->mov(c, J.C / 8 - 1); |
768 | Label lrn_loop; |
769 | this->L(lrn_loop); |
770 | |
771 | this->vmovups(yc, this->ptr[src_]); |
772 | this->vmovups(yd, this->ptr[src_ + 4]); |
773 | this->vmovups(ye, this->ptr[src_ + 8]); |
774 | this->vfmadd231ps(ysum, yc, yc); |
775 | this->vfmadd231ps(ysum, yd, yd); |
776 | this->vfmadd231ps(ysum, ye, ye); |
777 | |
778 | this->vmovups(ydst, ysum); |
779 | this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_ |
780 | |
781 | this->vmovaps(ybase, ydst); |
782 | if (pk_ != prop_kind::forward_inference) |
783 | this->vmovups(this->ptr[scratch_], ybase); |
784 | this->vmulps(ydst, ydst, ydst); |
785 | this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3; |
786 | this->vsqrtps(ydst, ydst); |
787 | this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75 |
788 | |
789 | this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75 |
790 | this->vmovups(this->ptr[dst_], ydst); |
791 | |
792 | this->vxorps(ysum, ysum, ysum); |
793 | |
794 | this->add(src_, 32); |
795 | this->add(dst_, 32); |
796 | if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32); |
797 | |
798 | this->vmovups(ya, this->ptr[src_ - 8]); |
799 | this->vfmadd231ps(ysum, ya, ya); |
800 | this->vmovups(yb, this->ptr[src_ - 4]); |
801 | this->vfmadd231ps(ysum, yb, yb); |
802 | |
803 | this->dec(c); |
804 | this->cmp(c, 0); |
805 | this->jne(lrn_loop, this->T_NEAR); |
806 | |
807 | this->vmovups(yc, this->ptr[src_]); |
808 | this->vfmadd231ps(ysum, yc, yc); |
809 | |
810 | this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[2])); |
811 | this->vmovups(ymask, this->ptr[this->imm_addr64_]); |
812 | this->vmaskmovps(yd, ymask, this->ptr[src_ + 4]); |
813 | this->vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 |
814 | |
815 | this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[3])); |
816 | this->vmovups(ymask, this->ptr[this->imm_addr64_]); |
817 | this->vmaskmovps(ye, ymask, this->ptr[src_ + 8]); |
818 | this->vfmadd231ps(ysum, ye, ye); |
819 | |
820 | this->vmovups(ydst, ysum); |
821 | this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_ |
822 | |
823 | this->vmovaps(ybase, ydst); |
824 | if (pk_ != prop_kind::forward_inference) |
825 | this->vmovups(this->ptr[scratch_], ybase); |
826 | this->vmulps(ydst, ydst, ydst); |
827 | this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3; |
828 | this->vsqrtps(ydst, ydst); |
829 | this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75 |
830 | this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75 |
831 | |
832 | this->vmovups(this->ptr[dst_], ydst); |
833 | |
834 | this->postamble(); |
835 | } |
836 | |
/* Constructor of the SSE4.1 / f32 forward LRN kernel for the nhwc_across
 * configuration. It only records its arguments; no code is emitted here.
 *   J                   - nhwc_across_t descriptor, copied into nhwc_across_
 *   A, K                - stored as alpha_ and k_
 *   pk                  - propagation kind, stored as pk_
 *   code_ptr, code_size - forwarded to Base together with jit_name() */
template <>
jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
        jit_uni_lrn_fwd_kernel_t(const struct nhwc_across_t &J, float A,
                float K, prop_kind_t pk, void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size, jit_name())
    , config_(lrn_config_t::nhwc_across)
    , nhwc_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
847 | |
848 | template <> |
849 | void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate( |
850 | const nhwc_across_t &J) { |
851 | static uint32_t store[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; |
852 | const Xbyak::Reg64 c = this->r9; |
853 | |
854 | const Xbyak::Xmm &xdst_lo = this->xmm0; |
855 | const Xbyak::Xmm &xdst_hi = this->xmm1; |
856 | const Xbyak::Xmm &xa_lo = this->xmm2; |
857 | const Xbyak::Xmm &xa_hi = this->xmm3; |
858 | const Xbyak::Xmm &xb_lo = this->xmm2; |
859 | const Xbyak::Xmm &xb_hi = this->xmm3; |
860 | const Xbyak::Xmm &xc_lo = this->xmm4; |
861 | const Xbyak::Xmm &xc_hi = this->xmm5; |
862 | const Xbyak::Xmm &xd_lo = this->xmm6; |
863 | const Xbyak::Xmm &xd_hi = this->xmm7; |
864 | const Xbyak::Xmm &xe_lo = this->xmm8; |
865 | const Xbyak::Xmm &xe_hi = this->xmm9; |
866 | const Xbyak::Xmm &xsum_lo = this->xmm10; |
867 | const Xbyak::Xmm &xsum_hi = this->xmm11; |
868 | // unused: xmm12, xmm13; |
869 | const Xbyak::Xmm &xbase_lo = this->xmm14; |
870 | const Xbyak::Xmm &xbase_hi = this->xmm15; |
871 | |
872 | this->preamble(); |
873 | if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16(); |
874 | |
875 | this->mov(src_, this->ptr[this->param1 + 0]); |
876 | this->mov(dst_, this->ptr[this->param1 + 8]); |
877 | if (pk_ != prop_kind::forward_inference) |
878 | mov(scratch_, this->ptr[this->param1 + 16]); |
879 | this->mov(this->imm_addr64_, float2int(this->alpha_)); |
880 | this->movq(xalpha_, this->imm_addr64_); |
881 | this->shufps(xalpha_, xalpha_, 0); |
882 | |
883 | this->mov(this->imm_addr64_, float2int(this->k_)); |
884 | this->movq(xk_, this->imm_addr64_); |
885 | this->shufps(xk_, xk_, 0); |
886 | |
887 | this->mov(store_addr_, reinterpret_cast<size_t>(&store[0])); |
888 | this->and_(store_addr_, -15); |
889 | this->movups(this->ptr[store_addr_], xalpha_); |
890 | this->movups(this->ptr[store_addr_ + 4 * sizeof(float)], xk_); |
891 | |
892 | this->xorps(xsum_lo, xsum_lo); |
893 | this->xorps(xsum_hi, xsum_hi); |
894 | |
895 | /* load the 2 first blocks of channels |
896 | * block: | -- low -- | -- hi -- | |
897 | * C: [c1,c2,c3,c4,c5,c6,c7,c8] |
898 | * xa_lo << 2 [0,0,c1,c2] |
899 | * xa_hi [c3,c4,c5,c6] |
900 | * xb_lo << 1 [0,c1,c2,c3] |
901 | * xb_hi [c4,c5,c6,c7] |
902 | * | -- data -- (...) |
903 | * ^ memory boundary |
904 | */ |
905 | this->movups(xa_lo, this->ptr[src_]); |
906 | this->movups(xa_hi, this->ptr[src_ + 2 * sizeof(float)]); |
907 | this->pslldq(xa_lo, 2 * sizeof(float)); |
908 | this->mulps(xa_lo, xa_lo); |
909 | this->mulps(xa_hi, xa_hi); |
910 | this->addps(xsum_lo, xa_lo); |
911 | this->addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 |
912 | |
913 | this->movups(xb_lo, this->ptr[src_]); |
914 | this->movups(xb_hi, this->ptr[src_ + 3 * sizeof(float)]); |
915 | this->pslldq(xb_lo, 1 * sizeof(float)); |
916 | this->mulps(xb_lo, xb_lo); |
917 | this->mulps(xb_hi, xb_hi); |
918 | this->addps(xsum_lo, xb_lo); |
919 | this->addps(xsum_hi, xb_hi); |
920 | |
921 | this->mov(c, J.C / 8 - 1); |
922 | Label lrn_loop; |
923 | this->L(lrn_loop); |
924 | |
925 | this->movups(xc_lo, this->ptr[src_]); |
926 | this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]); |
927 | this->movups(xd_lo, this->ptr[src_ + 4]); |
928 | this->movups(xd_hi, this->ptr[src_ + 4 + 4 * sizeof(float)]); |
929 | this->movups(xe_lo, this->ptr[src_ + 8]); |
930 | this->movups(xe_hi, this->ptr[src_ + 8 + 4 * sizeof(float)]); |
931 | this->mulps(xc_lo, xc_lo); |
932 | this->mulps(xc_hi, xc_hi); |
933 | this->addps(xsum_lo, xc_lo); |
934 | this->addps(xsum_hi, xc_hi); |
935 | this->mulps(xd_lo, xd_lo); |
936 | this->mulps(xd_hi, xd_hi); |
937 | this->addps(xsum_lo, xd_lo); |
938 | this->addps(xsum_hi, xd_hi); |
939 | this->mulps(xe_lo, xe_lo); |
940 | this->mulps(xe_hi, xe_hi); |
941 | this->addps(xsum_lo, xe_lo); |
942 | this->addps(xsum_hi, xe_hi); |
943 | |
944 | this->movaps(xdst_lo, xsum_lo); |
945 | this->movaps(xdst_hi, xsum_hi); |
946 | // xdst <- xsum*xalpha_+xk_ |
947 | this->mulps(xdst_lo, this->ptr[store_addr_]); |
948 | this->mulps(xdst_hi, this->ptr[store_addr_]); |
949 | this->addps(xdst_lo, this->ptr[store_addr_ + 4 * sizeof(float)]); |
950 | this->addps(xdst_hi, this->ptr[store_addr_ + 4 * sizeof(float)]); |
951 | |
952 | this->movaps(xbase_lo, xdst_lo); |
953 | this->movaps(xbase_hi, xdst_hi); |
954 | if (pk_ != prop_kind::forward_inference) { |
955 | this->movups(this->ptr[scratch_], xbase_lo); |
956 | this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi); |
957 | } |
958 | this->mulps(xdst_lo, xdst_lo); |
959 | this->mulps(xdst_hi, xdst_hi); |
960 | this->mulps(xdst_lo, xbase_lo); |
961 | this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3; |
962 | this->sqrtps(xdst_lo, xdst_lo); |
963 | this->sqrtps(xdst_hi, xdst_hi); |
964 | this->sqrtps(xdst_lo, xdst_lo); |
965 | this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75 |
966 | |
967 | this->movups(xc_lo, this->ptr[src_]); |
968 | this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]); |
969 | this->divps(xc_lo, xdst_lo); |
970 | this->divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75 |
971 | this->movups(this->ptr[dst_], xc_lo); |
972 | this->movups(this->ptr[dst_ + 4 * sizeof(float)], xc_hi); |
973 | |
974 | this->xorps(xsum_lo, xsum_lo); |
975 | this->xorps(xsum_hi, xsum_hi); |
976 | |
977 | this->add(src_, 32); |
978 | this->add(dst_, 32); |
979 | if (pk_ != prop_kind::forward_inference) this->add(scratch_, 32); |
980 | |
981 | this->movups(xa_lo, this->ptr[src_ - 8]); |
982 | this->movups(xa_hi, this->ptr[src_ - 8 + 4 * sizeof(float)]); |
983 | this->mulps(xa_lo, xa_lo); |
984 | this->mulps(xa_hi, xa_hi); |
985 | this->addps(xsum_lo, xa_lo); |
986 | this->addps(xsum_hi, xa_hi); |
987 | this->movups(xb_lo, this->ptr[src_ - 4]); |
988 | this->movups(xb_hi, this->ptr[src_ - 4 + 4 * sizeof(float)]); |
989 | this->mulps(xb_lo, xb_lo); |
990 | this->mulps(xb_hi, xb_hi); |
991 | this->addps(xsum_lo, xb_lo); |
992 | this->addps(xsum_hi, xb_hi); |
993 | |
994 | this->dec(c); |
995 | this->cmp(c, 0); |
996 | this->jne(lrn_loop, this->T_NEAR); |
997 | |
998 | /* compute last 3 blocks of channels: |
999 | * block: | -- low -- | -- hi -- | |
1000 | * C: [c1,c2,c3,c4,c5,c6,c7,c8] |
1001 | * xc_lo|xc_hi [c1,c2,c3,c4|c5,c6,c7,c8] |
1002 | * xd_lo [c2,c3,c4,c5] |
1003 | * xd_hi >> 1 [c6,c7,c8, 0] |
1004 | * xe_lo [c3,c4,c5,c6] |
1005 | * xe_hi >> 2 [c7,c8, 0, 0] |
1006 | * (...) -- data -- | -- illegal reading -- (...) |
1007 | * ^ memory boundary |
1008 | */ |
1009 | this->movups(xc_lo, this->ptr[src_]); |
1010 | this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]); |
1011 | this->mulps(xc_lo, xc_lo); |
1012 | this->mulps(xc_hi, xc_hi); |
1013 | this->addps(xsum_lo, xc_lo); |
1014 | this->addps(xsum_hi, xc_hi); |
1015 | |
1016 | this->movups(xd_lo, this->ptr[src_ + 1 * sizeof(float)]); |
1017 | this->movups(xd_hi, this->ptr[src_ + 4 * sizeof(float)]); |
1018 | this->psrldq(xd_hi, 1 * sizeof(float)); |
1019 | this->mulps(xd_lo, xd_lo); |
1020 | this->mulps(xd_hi, xd_hi); |
1021 | this->addps(xsum_lo, xd_lo); |
1022 | this->addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 |
1023 | |
1024 | this->movups(xe_lo, this->ptr[src_ + 2 * sizeof(float)]); |
1025 | this->movups(xe_hi, this->ptr[src_ + 4 * sizeof(float)]); |
1026 | this->psrldq(xe_hi, 2 * sizeof(float)); |
1027 | this->mulps(xe_lo, xe_lo); |
1028 | this->mulps(xe_hi, xe_hi); |
1029 | this->addps(xsum_lo, xe_lo); |
1030 | this->addps(xsum_hi, xe_hi); |
1031 | |
1032 | this->movups(xdst_lo, xsum_lo); |
1033 | this->movups(xdst_hi, xsum_hi); |
1034 | // xdst <- xsum*xalpha_+xk_ |
1035 | this->mulps(xdst_lo, this->ptr[store_addr_]); |
1036 | this->mulps(xdst_hi, this->ptr[store_addr_]); |
1037 | this->addps(xdst_lo, this->ptr[store_addr_ + 4 * sizeof(float)]); |
1038 | this->addps(xdst_hi, this->ptr[store_addr_ + 4 * sizeof(float)]); |
1039 | |
1040 | this->movaps(xbase_lo, xdst_lo); |
1041 | this->movaps(xbase_hi, xdst_hi); |
1042 | if (pk_ != prop_kind::forward_inference) { |
1043 | this->movups(this->ptr[scratch_], xbase_lo); |
1044 | this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi); |
1045 | } |
1046 | this->mulps(xdst_lo, xdst_lo); |
1047 | this->mulps(xdst_hi, xdst_hi); |
1048 | this->mulps(xdst_lo, xbase_lo); |
1049 | this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3; |
1050 | this->sqrtps(xdst_lo, xdst_lo); |
1051 | this->sqrtps(xdst_hi, xdst_hi); |
1052 | this->sqrtps(xdst_lo, xdst_lo); |
1053 | this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75 |
1054 | this->movups(xc_lo, this->ptr[src_]); |
1055 | this->movups(xc_hi, this->ptr[src_ + 4 * sizeof(float)]); |
1056 | this->divps(xc_lo, xdst_lo); |
1057 | this->divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75 |
1058 | |
1059 | this->movups(this->ptr[dst_], xc_lo); |
1060 | this->movups(this->ptr[dst_ + 4 * sizeof(float)], xc_hi); |
1061 | |
1062 | this->postamble(); |
1063 | } |
1064 | |
/* SSE4.1 / f32 specialization of the Ymm-based nchw_body helper.
 * Intentionally emits nothing: the SSE4.1 generate(const nchw_across_t &)
 * in this file calls nchw_body_sse41 instead. */
template <>
void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::nchw_body(
        int tail, int HW, prop_kind_t pk, Xbyak::Ymm ymask, Xbyak::Ymm ya,
        Xbyak::Ymm yb, Xbyak::Ymm yc, Xbyak::Ymm yd, Xbyak::Ymm ye,
        Xbyak::Ymm ysum) {}
1070 | |
1071 | template <cpu_isa_t isa, data_type_t d_type> |
1072 | void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_body(int tail, int HW, |
1073 | prop_kind_t pk, Xbyak::Ymm ymask, Xbyak::Ymm ya, Xbyak::Ymm yb, |
1074 | Xbyak::Ymm yc, Xbyak::Ymm yd, Xbyak::Ymm ye, Xbyak::Ymm ysum) { |
1075 | const Xbyak::Ymm &ydst = this->ymm14; |
1076 | const Xbyak::Ymm &ybase = this->ymm15; |
1077 | |
1078 | this->vfmadd231ps(ysum, ye, ye); |
1079 | |
1080 | this->vmovups(ydst, ysum); |
1081 | this->vfmadd132ps(ydst, yk_, valpha_); // ydst <- ysum*valpha_+yk_ |
1082 | |
1083 | this->vmovaps(ybase, ydst); |
1084 | if (pk_ != prop_kind::forward_inference) { |
1085 | if (tail != 0) |
1086 | this->vmaskmovps(this->ptr[scratch_], ymask, ybase); |
1087 | else |
1088 | this->vmovups(this->ptr[scratch_], ybase); |
1089 | } |
1090 | this->vmulps(ydst, ydst, ydst); |
1091 | this->vmulps(ydst, ydst, ybase); // ydst = (ysum*valpha_+yk_)^3; |
1092 | this->vsqrtps(ydst, ydst); |
1093 | this->vsqrtps(ydst, ydst); // ydst = (ysum*valpha_+yk_)^0.75 |
1094 | this->vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*valpha_+yk_)^0.75 |
1095 | |
1096 | if (tail != 0) |
1097 | this->vmaskmovps(this->ptr[dst_], ymask, ydst); |
1098 | else |
1099 | this->vmovups(this->ptr[dst_], ydst); |
1100 | |
1101 | this->vfnmadd231ps(ysum, ya, ya); |
1102 | this->vmovups(ya, yb); |
1103 | this->vmovups(yb, yc); |
1104 | this->vmovups(yc, yd); |
1105 | this->vmovups(yd, ye); |
1106 | } |
1107 | |
/* Generic nchw_tail_sse41: intentionally emits nothing. Only the
 * sse41 / f32 explicit specialization in this file provides a body. */
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_tail_sse41(int tail,
        Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) {}
1111 | |
1112 | template <> |
1113 | void jit_uni_lrn_fwd_kernel_t<sse41, |
1114 | dnnl::impl::data_type::f32>::nchw_tail_sse41(int tail, |
1115 | Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) { |
1116 | Xbyak::Xmm xmm_tmp = xmm10; |
1117 | this->movaps(xmm_tmp, xtail_hi); |
1118 | |
1119 | if (tail > 3) { |
1120 | /* Store upper-half directly */ |
1121 | this->movups(this->ptr[reg_dst + (tail - 4) * sizeof(float)], xtail_hi); |
1122 | this->movaps(xmm_tmp, xtail_lo); |
1123 | tail -= 4; |
1124 | } |
1125 | if (tail > 0) { |
1126 | /* Store on a single-element basis when 'tail' overlaps |
1127 | * with 'src_' */ |
1128 | this->psrldq(xmm_tmp, (4 - tail) * sizeof(float)); |
1129 | this->movss(this->ptr[reg_dst], xmm_tmp); |
1130 | |
1131 | for (int i = 1; i < tail; i++) { |
1132 | this->psrldq(xmm_tmp, sizeof(float)); |
1133 | this->movss(this->ptr[reg_dst + i * sizeof(float)], xmm_tmp); |
1134 | } |
1135 | } |
1136 | } |
1137 | |
1138 | template <> |
1139 | void jit_uni_lrn_fwd_kernel_t<sse41, |
1140 | dnnl::impl::data_type::f32>::nchw_body_sse41(int tail, int HW, |
1141 | prop_kind_t pk, Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo, |
1142 | Xbyak::Xmm xsum_hi) { |
1143 | const Xbyak::Xmm &xdst_lo = this->xmm0; |
1144 | const Xbyak::Xmm &xdst_hi = this->xmm1; |
1145 | const Xbyak::Xmm &xbase_lo = this->xmm6; |
1146 | const Xbyak::Xmm &xbase_hi = this->xmm7; |
1147 | const Xbyak::Xmm &xtmp_lo = this->xmm8; |
1148 | const Xbyak::Xmm &xtmp_hi = this->xmm9; |
1149 | const Xbyak::Xmm &xa_lo = this->xmm6; |
1150 | const Xbyak::Xmm &xa_hi = this->xmm7; |
1151 | const Xbyak::Xmm &xb_lo = this->xmm8; |
1152 | const Xbyak::Xmm &xb_hi = this->xmm9; |
1153 | const Xbyak::Xmm &xc_lo = this->xmm10; |
1154 | const Xbyak::Xmm &xc_hi = this->xmm11; |
1155 | const Xbyak::Xmm &xd_lo = this->xmm12; |
1156 | const Xbyak::Xmm &xd_hi = this->xmm13; |
1157 | |
1158 | // store xe |
1159 | this->movaps(this->ptr[store_addr_ + 10 * 4 * sizeof(float)], xe_lo); |
1160 | this->movaps(this->ptr[store_addr_ + 11 * 4 * sizeof(float)], xe_hi); |
1161 | |
1162 | this->mulps(xe_lo, xe_lo); |
1163 | this->mulps(xe_hi, xe_hi); |
1164 | this->addps(xsum_lo, xe_lo); |
1165 | this->addps(xsum_hi, xe_hi); |
1166 | |
1167 | // xdst <- xsum*xalpha_+xk_ |
1168 | this->movaps(xdst_lo, xsum_lo); |
1169 | this->movaps(xdst_hi, xsum_hi); |
1170 | this->mulps(xdst_lo, this->ptr[store_addr_ + 0 * 4 * sizeof(float)]); |
1171 | this->mulps(xdst_hi, this->ptr[store_addr_ + 0 * 4 * sizeof(float)]); |
1172 | this->addps(xdst_lo, this->ptr[store_addr_ + 1 * 4 * sizeof(float)]); |
1173 | this->addps(xdst_hi, this->ptr[store_addr_ + 1 * 4 * sizeof(float)]); |
1174 | |
1175 | this->movaps(xbase_lo, xdst_lo); |
1176 | this->movaps(xbase_hi, xdst_hi); |
1177 | if (pk_ != prop_kind::forward_inference) { |
1178 | if (tail != 0) { |
1179 | nchw_tail_sse41(tail, scratch_, xbase_lo, xbase_hi); |
1180 | } else { |
1181 | this->movups(this->ptr[scratch_], xbase_lo); |
1182 | this->movups(this->ptr[scratch_ + 4 * sizeof(float)], xbase_hi); |
1183 | } |
1184 | } |
1185 | this->mulps(xdst_lo, xdst_lo); |
1186 | this->mulps(xdst_hi, xdst_hi); |
1187 | this->mulps(xdst_lo, xbase_lo); |
1188 | this->mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha_+xk_)^3; |
1189 | this->sqrtps(xdst_lo, xdst_lo); |
1190 | this->sqrtps(xdst_hi, xdst_hi); |
1191 | this->sqrtps(xdst_lo, xdst_lo); |
1192 | this->sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha_+xk_)^0.75 |
1193 | this->movaps(xtmp_lo, this->ptr[store_addr_ + 6 * 4 * sizeof(float)]); |
1194 | this->movaps(xtmp_hi, this->ptr[store_addr_ + 7 * 4 * sizeof(float)]); |
1195 | this->divps(xtmp_lo, xdst_lo); |
1196 | this->divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha_+xk_)^0.75 |
1197 | this->movaps(xdst_lo, xtmp_lo); |
1198 | this->movaps(xdst_hi, xtmp_hi); |
1199 | |
1200 | if (tail != 0) { |
1201 | nchw_tail_sse41(tail, dst_, xdst_lo, xdst_hi); |
1202 | } else { |
1203 | this->movups(this->ptr[dst_], xdst_lo); |
1204 | this->movups(this->ptr[dst_ + 4 * sizeof(float)], xdst_hi); |
1205 | } |
1206 | |
1207 | this->movaps(xa_lo, this->ptr[store_addr_ + 2 * 4 * sizeof(float)]); |
1208 | this->movaps(xa_hi, this->ptr[store_addr_ + 3 * 4 * sizeof(float)]); |
1209 | this->mulps(xa_lo, xa_lo); |
1210 | this->mulps(xa_hi, xa_hi); |
1211 | this->subps(xsum_lo, xa_lo); |
1212 | this->subps(xsum_hi, xa_hi); |
1213 | |
1214 | // xa <- xb |
1215 | this->movaps(xb_lo, this->ptr[store_addr_ + 4 * 4 * sizeof(float)]); |
1216 | this->movaps(xb_hi, this->ptr[store_addr_ + 5 * 4 * sizeof(float)]); |
1217 | this->movaps(this->ptr[store_addr_ + 2 * 4 * sizeof(float)], xb_lo); |
1218 | this->movaps(this->ptr[store_addr_ + 3 * 4 * sizeof(float)], xb_hi); |
1219 | |
1220 | // xb <- xc |
1221 | this->movaps(xc_lo, this->ptr[store_addr_ + 6 * 4 * sizeof(float)]); |
1222 | this->movaps(xc_hi, this->ptr[store_addr_ + 7 * 4 * sizeof(float)]); |
1223 | this->movaps(this->ptr[store_addr_ + 4 * 4 * sizeof(float)], xc_lo); |
1224 | this->movaps(this->ptr[store_addr_ + 5 * 4 * sizeof(float)], xc_hi); |
1225 | |
1226 | // xc <- xd |
1227 | this->movaps(xd_lo, this->ptr[store_addr_ + 8 * 4 * sizeof(float)]); |
1228 | this->movaps(xd_hi, this->ptr[store_addr_ + 9 * 4 * sizeof(float)]); |
1229 | this->movaps(this->ptr[store_addr_ + 6 * 4 * sizeof(float)], xd_lo); |
1230 | this->movaps(this->ptr[store_addr_ + 7 * 4 * sizeof(float)], xd_hi); |
1231 | |
1232 | // xd <- xe |
1233 | this->movaps(xe_lo, this->ptr[store_addr_ + 10 * 4 * sizeof(float)]); |
1234 | this->movaps(xe_hi, this->ptr[store_addr_ + 11 * 4 * sizeof(float)]); |
1235 | this->movaps(this->ptr[store_addr_ + 8 * 4 * sizeof(float)], xe_lo); |
1236 | this->movaps(this->ptr[store_addr_ + 9 * 4 * sizeof(float)], xe_hi); |
1237 | } |
1238 | |
/* Generic nchw_body_sse41: intentionally emits nothing. Only the
 * sse41 / f32 explicit specialization in this file provides a body; the
 * generic generate(const nchw_across_t &) calls nchw_body instead. */
template <cpu_isa_t isa, data_type_t d_type>
void jit_uni_lrn_fwd_kernel_t<isa, d_type>::nchw_body_sse41(int tail, int HW,
        prop_kind_t pk, Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo,
        Xbyak::Xmm xsum_hi) {}
1243 | |
/* Constructor of the forward LRN kernel for the nchw_across configuration.
 * It only records its arguments; no code is emitted here.
 *   J                   - nchw_across_t descriptor, copied into nchw_across_
 *   A, K                - stored as alpha_ and k_
 *   pk                  - propagation kind, stored as pk_
 *   code_ptr, code_size - forwarded to Base together with jit_name() */
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::jit_uni_lrn_fwd_kernel_t(
        const nchw_across_t &J, float A, float K, prop_kind_t pk,
        void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size, jit_name())
    , config_(lrn_config_t::nchw_across)
    , nchw_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
1254 | |
1255 | template <cpu_isa_t isa, data_type_t d_type> |
1256 | void jit_uni_lrn_fwd_kernel_t<isa, d_type>::generate(const nchw_across_t &J) { |
1257 | static const uint32_t mask[] |
1258 | = {0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, |
1259 | 0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0}; |
1260 | const Xbyak::Reg64 &c = this->r10; |
1261 | const Xbyak::Ymm &ymask = this->ymm2; |
1262 | const Xbyak::Ymm &ye = this->ymm3; |
1263 | const Xbyak::Ymm &ya = this->ymm4; |
1264 | const Xbyak::Ymm &yb = this->ymm5; |
1265 | const Xbyak::Ymm &yc = this->ymm6; |
1266 | const Xbyak::Ymm &yd = this->ymm7; |
1267 | const Xbyak::Ymm &ysum = this->ymm8; |
1268 | |
1269 | this->preamble(); |
1270 | if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16(); |
1271 | |
1272 | if (J.tail != 0) { |
1273 | this->mov( |
1274 | this->imm_addr64_, reinterpret_cast<size_t>(&mask[7 - J.tail])); |
1275 | this->vmovups(ymask, this->ptr[this->imm_addr64_]); |
1276 | } |
1277 | this->mov(this->imm_addr64_, float2int(this->alpha_)); |
1278 | this->vmovq(xalpha_, this->imm_addr64_); |
1279 | this->vbroadcastss(valpha_, xalpha_); |
1280 | |
1281 | this->mov(this->imm_addr64_, float2int(this->k_)); |
1282 | this->vmovq(xk_, this->imm_addr64_); |
1283 | this->vbroadcastss(yk_, xk_); |
1284 | |
1285 | this->mov(src_, this->ptr[this->param1 + 0]); |
1286 | this->mov(dst_, this->ptr[this->param1 + 8]); |
1287 | if (pk_ != prop_kind::forward_inference) |
1288 | this->mov(scratch_, this->ptr[this->param1 + 16]); |
1289 | |
1290 | this->vxorps(ya, ya, ya); |
1291 | this->vxorps(yb, yb, yb); |
1292 | if (J.tail != 0) |
1293 | this->vmaskmovps(yc, ymask, this->ptr[src_ + J.HW * 0]); |
1294 | else |
1295 | this->vmovups(yc, this->ptr[src_ + J.HW * 0]); |
1296 | if (J.tail != 0) |
1297 | this->vmaskmovps(yd, ymask, this->ptr[src_ + J.HW * 4]); |
1298 | else |
1299 | this->vmovups(yd, this->ptr[src_ + J.HW * 4]); |
1300 | |
1301 | this->vxorps(ysum, ysum, ysum); |
1302 | this->vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 |
1303 | this->vfmadd231ps(ysum, yd, yd); |
1304 | |
1305 | this->mov(c, J.C - 2); |
1306 | Label lrn_loop; |
1307 | this->L(lrn_loop); |
1308 | |
1309 | if (J.tail != 0) |
1310 | this->vmaskmovps(ye, ymask, this->ptr[src_ + J.HW * 8]); |
1311 | else |
1312 | this->vmovups(ye, this->ptr[src_ + J.HW * 8]); |
1313 | |
1314 | nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum); |
1315 | |
1316 | this->add(src_, J.HW * 4); |
1317 | this->add(dst_, J.HW * 4); |
1318 | if (pk_ != prop_kind::forward_inference) this->add(scratch_, J.HW * 4); |
1319 | this->dec(c); |
1320 | this->cmp(c, 0); |
1321 | this->jne(lrn_loop, this->T_NEAR); |
1322 | |
1323 | this->vxorps(ye, ye, ye); |
1324 | |
1325 | nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum); |
1326 | this->add(src_, J.HW * 4); |
1327 | this->add(dst_, J.HW * 4); |
1328 | if (pk_ != prop_kind::forward_inference) this->add(scratch_, J.HW * 4); |
1329 | |
1330 | nchw_body(J.tail, J.HW, pk_, ymask, ya, yb, yc, yd, ye, ysum); |
1331 | |
1332 | this->postamble(); |
1333 | } |
1334 | |
/* Destructor, defaulted out of line in this translation unit. */
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_kernel_t<isa, d_type>::~jit_uni_lrn_fwd_kernel_t() = default;
1337 | |
/* Constructor of the SSE4.1 / f32 forward LRN kernel for the nchw_across
 * configuration. It only records its arguments; no code is emitted here.
 *   J                   - nchw_across_t descriptor, copied into nchw_across_
 *   A, K                - stored as alpha_ and k_
 *   pk                  - propagation kind, stored as pk_
 *   code_ptr, code_size - forwarded to Base together with jit_name() */
template <>
jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::
        jit_uni_lrn_fwd_kernel_t(const nchw_across_t &J, float A, float K,
                prop_kind_t pk, void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size, jit_name())
    , config_(lrn_config_t::nchw_across)
    , nchw_across_(J)
    , alpha_(A)
    , k_(K)
    , pk_(pk) {}
1348 | |
1349 | template <> |
1350 | void jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>::generate( |
1351 | const nchw_across_t &J) { |
1352 | |
1353 | /* Load from within the memory boundary of 'src_' and apply a zero-mask to |
1354 | * the 'x_hi' register: |
1355 | * block: src_ |tail = 3 |
1356 | * src_: [x,x,x,x|a,b,c] |
1357 | * x_hi: [x,a,b,c] |
1358 | * mask: [0,1,1,1] |
1359 | * (...) -- data -- | -- illegal reading -- (...) |
1360 | * ^ memory boundary |
1361 | * |
1362 | * 'x_lo' is loaded with the elements between 'src_' and 'x_hi' when |
1363 | * tail.size is between [5:7]. The register is then left-shifted to |
1364 | * clear the overlapping elements with 'x_hi'. |
1365 | * block: - src_ - | tail = 7 |
1366 | * src_: (...) [x,|a,b,c,d,e,f,g] |
1367 | * x_hi [d,e,f,g] |
1368 | * x_lo [a,b,c,d] |
1369 | * x_lo >> 1: [0,a,b,c] |
1370 | * (...) -- data -- | -- illegal reading -- (...) |
1371 | * ^ memory boundary |
1372 | * |
1373 | * - seg-fault happens if read occurs anywhere outside the |
1374 | * memory boundary. |
1375 | * */ |
1376 | static const uint32_t mask[] |
1377 | = {0, 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff}; |
1378 | assert(J.HW > 3); |
1379 | |
1380 | const Xbyak::Reg64 &c = r10; |
1381 | |
1382 | // unused: xmm2 |
1383 | const Xbyak::Xmm &xmask_hi = this->xmm3; |
1384 | const Xbyak::Xmm &xsum_lo = this->xmm4; |
1385 | const Xbyak::Xmm &xsum_hi = this->xmm5; |
1386 | const Xbyak::Xmm &xa_lo = this->xmm6; |
1387 | const Xbyak::Xmm &xa_hi = this->xmm7; |
1388 | const Xbyak::Xmm &xb_lo = this->xmm8; |
1389 | const Xbyak::Xmm &xb_hi = this->xmm9; |
1390 | const Xbyak::Xmm &xc_lo = this->xmm10; |
1391 | const Xbyak::Xmm &xc_hi = this->xmm11; |
1392 | const Xbyak::Xmm &xd_lo = this->xmm12; |
1393 | const Xbyak::Xmm &xd_hi = this->xmm13; |
1394 | const Xbyak::Xmm &xe_lo = this->xmm14; |
1395 | const Xbyak::Xmm &xe_hi = this->xmm15; |
1396 | |
1397 | const int vlen = cpu_isa_traits<sse41>::vlen / sizeof(float); |
1398 | |
1399 | bool compute_tail = J.tail != 0; |
1400 | bool load_lo = J.tail == 0 || J.tail > 4; |
1401 | |
1402 | size_t h_offset = vlen; |
1403 | size_t l_shift = 0; |
1404 | |
1405 | this->preamble(); |
1406 | if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16(); |
1407 | |
1408 | this->mov(src_, this->ptr[this->param1 + 0]); |
1409 | this->mov(dst_, this->ptr[this->param1 + 8]); |
1410 | if (pk_ != prop_kind::forward_inference) |
1411 | this->mov(scratch_, this->ptr[this->param1 + 16]); |
1412 | |
1413 | this->sub(rsp, stack_space_needed_); |
1414 | this->mov(store_addr_, rsp); |
1415 | this->and_(store_addr_, -15); |
1416 | |
1417 | this->mov(this->imm_addr64_, float2int(this->alpha_)); |
1418 | this->movq(xalpha_, this->imm_addr64_); |
1419 | this->shufps(xalpha_, xalpha_, 0); |
1420 | |
1421 | this->mov(this->imm_addr64_, float2int(this->k_)); |
1422 | this->movq(xk_, this->imm_addr64_); |
1423 | this->shufps(xk_, xk_, 0); |
1424 | |
1425 | // put alpha_ and k_ into store (free up regs) |
1426 | this->movaps(this->ptr[store_addr_ + 0 * 4 * sizeof(float)], xalpha_); |
1427 | this->movaps(this->ptr[store_addr_ + 1 * 4 * sizeof(float)], xk_); |
1428 | |
1429 | if (compute_tail) { |
1430 | assert(J.tail > 0 && J.tail < 2 * vlen); |
1431 | h_offset = J.tail - vlen; |
1432 | l_shift = nstl::min(2 * vlen - J.tail, vlen); |
1433 | |
1434 | /* if 'tail' is between [1:3], need to zero-mask for underflow */ |
1435 | size_t m_off = nstl::min(J.tail - 1, 3); |
1436 | this->mov(this->imm_addr64_, reinterpret_cast<size_t>(&mask[m_off])); |
1437 | this->movups(xmask_hi, this->ptr[this->imm_addr64_]); |
1438 | } |
1439 | // init xa, xb |
1440 | this->xorps(xa_lo, xa_lo); |
1441 | this->xorps(xa_hi, xa_hi); |
1442 | this->xorps(xb_lo, xb_lo); |
1443 | this->xorps(xb_hi, xb_hi); |
1444 | |
1445 | // read xc, xd |
1446 | if (load_lo) this->movups(xc_lo, this->ptr[src_ + J.HW * 0]); |
1447 | this->movups(xc_hi, this->ptr[src_ + J.HW * 0 + h_offset * sizeof(float)]); |
1448 | if (compute_tail) { |
1449 | this->pslldq(xc_lo, l_shift * sizeof(float)); |
1450 | this->andps(xc_hi, xmask_hi); |
1451 | } |
1452 | |
1453 | if (load_lo) this->movups(xd_lo, this->ptr[src_ + J.HW * 4]); |
1454 | this->movups(xd_hi, this->ptr[src_ + J.HW * 4 + h_offset * sizeof(float)]); |
1455 | if (compute_tail) { |
1456 | this->pslldq(xd_lo, l_shift * sizeof(float)); |
1457 | this->andps(xd_hi, xmask_hi); |
1458 | } |
1459 | |
1460 | // put xa, xb, xc, xd into store to free-up regs |
1461 | this->movaps(this->ptr[store_addr_ + 2 * 4 * sizeof(float)], xa_lo); |
1462 | this->movaps(this->ptr[store_addr_ + 3 * 4 * sizeof(float)], xa_hi); |
1463 | this->movaps(this->ptr[store_addr_ + 4 * 4 * sizeof(float)], xb_lo); |
1464 | this->movaps(this->ptr[store_addr_ + 5 * 4 * sizeof(float)], xb_hi); |
1465 | this->movaps(this->ptr[store_addr_ + 6 * 4 * sizeof(float)], xc_lo); |
1466 | this->movaps(this->ptr[store_addr_ + 7 * 4 * sizeof(float)], xc_hi); |
1467 | this->movaps(this->ptr[store_addr_ + 8 * 4 * sizeof(float)], xd_lo); |
1468 | this->movaps(this->ptr[store_addr_ + 9 * 4 * sizeof(float)], xd_hi); |
1469 | |
1470 | this->xorps(xsum_lo, xsum_lo); |
1471 | this->xorps(xsum_hi, xsum_hi); |
1472 | this->mulps(xc_lo, xc_lo); |
1473 | this->mulps(xc_hi, xc_hi); |
1474 | this->addps(xsum_lo, xc_lo); |
1475 | this->addps(xsum_hi, xc_hi); |
1476 | this->mulps(xd_lo, xd_lo); |
1477 | this->mulps(xd_hi, xd_hi); |
1478 | this->addps(xsum_lo, xd_lo); |
1479 | this->addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 |
1480 | |
1481 | this->mov(c, J.C - 2); |
1482 | Label lrn_loop; |
1483 | this->L(lrn_loop); |
1484 | |
1485 | if (load_lo) this->movups(xe_lo, this->ptr[src_ + J.HW * 8]); |
1486 | this->movups(xe_hi, this->ptr[src_ + J.HW * 8 + h_offset * sizeof(float)]); |
1487 | if (compute_tail) { |
1488 | this->pslldq(xe_lo, l_shift * sizeof(float)); |
1489 | this->andps(xe_hi, xmask_hi); |
1490 | } |
1491 | |
1492 | nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi); |
1493 | |
1494 | this->add(src_, J.HW * 4); |
1495 | this->add(dst_, J.HW * 4); |
1496 | if (pk_ != prop_kind::forward_inference) add(scratch_, J.HW * 4); |
1497 | this->dec(c); |
1498 | this->cmp(c, 0); |
1499 | this->jne(lrn_loop, this->T_NEAR); |
1500 | |
1501 | this->xorps(xe_lo, xe_lo); |
1502 | this->xorps(xe_hi, xe_hi); |
1503 | |
1504 | nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi); |
1505 | this->add(src_, J.HW * 4); |
1506 | this->add(dst_, J.HW * 4); |
1507 | if (pk_ != prop_kind::forward_inference) add(scratch_, J.HW * 4); |
1508 | |
1509 | nchw_body_sse41(J.tail, J.HW, pk_, xe_lo, xe_hi, xsum_lo, xsum_hi); |
1510 | |
1511 | this->add(rsp, stack_space_needed_); |
1512 | |
1513 | this->postamble(); |
1514 | } |
1515 | |
1516 | ////////////////////////////////////////////////////////////////////////////// |
1517 | // backward kernel |
/* Constructor of the backward LRN kernel for the nchw8c_across
 * configuration. It only records its arguments; no code is emitted here.
 *   J                   - nchw8c_across_t descriptor, copied into
 *                         nchw8c_across_
 *   A, B                - combined into nalphabeta_ = -2 * A * B
 *   use_h_parallel      - stored as use_h_parallelizm_
 *   code_ptr, code_size - forwarded to Base together with jit_name() */
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_bwd_kernel_t<isa, d_type>::jit_uni_lrn_bwd_kernel_t(
        const nchw8c_across_t &J, float A, float B, int use_h_parallel,
        void *code_ptr, size_t code_size)
    : Base(code_ptr, code_size, jit_name())
    , config_(lrn_config_t::nchw8c_across)
    , nchw8c_across_(J)
    , nalphabeta_(-2 * A * B)
    , use_h_parallelizm_(use_h_parallel) {}
1527 | |
1528 | template <cpu_isa_t isa, data_type_t d_type> |
1529 | void jit_uni_lrn_bwd_kernel_t<isa, d_type>::generate(const nchw8c_across_t &J) { |
1530 | |
1531 | const Xbyak::Reg64 &t = this->rsp; |
1532 | const Xbyak::Reg64 &hw = this->r10; |
1533 | const Xbyak::Xmm &xsrc_prev = this->xmm1; |
1534 | const Xbyak::Xmm &xws_prev = this->xmm2; |
1535 | const Xbyak::Xmm &xdiffdst_prev = this->xmm3; |
1536 | const Xbyak::Ymm &ysrc = this->ymm4; |
1537 | const Xbyak::Ymm &yws = this->ymm5; |
1538 | const Xbyak::Ymm &ydiffdst = this->ymm6; |
1539 | const Xbyak::Xmm &xsrc_next = this->xmm7; |
1540 | const Xbyak::Xmm &xws_next = this->xmm8; |
1541 | const Xbyak::Xmm &xdiffdst_next = this->xmm9; |
1542 | const Xbyak::Ymm &ya = this->ymm10; |
1543 | const Xbyak::Xmm &xa = this->xmm10; |
1544 | const Xbyak::Ymm &yb = this->ymm11; |
1545 | const Xbyak::Ymm &yd = this->ymm12; |
1546 | const Xbyak::Ymm &ye = this->ymm13; |
1547 | const Xbyak::Ymm &ysum = this->ymm14; |
1548 | const Xbyak::Ymm &ydiffsrc = this->ymm15; |
1549 | |
1550 | this->preamble(); |
1551 | if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16(); |
1552 | |
1553 | #define GET_OFF(field) offsetof(jit_args_bwd_t, field) |
1554 | this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]); |
1555 | this->mov(diffdst_, this->ptr[this->param1 + GET_OFF(diff_dst)]); |
1556 | this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]); |
1557 | this->mov(bwd_intermediate_res_, |
1558 | this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]); |
1559 | this->mov(diffsrc_, this->ptr[this->param1 + GET_OFF(diff_src)]); |
1560 | #undef GET_OFF |
1561 | |
1562 | this->sub(t, 64); |
1563 | this->mov(this->imm_addr64_, float2int(this->nalphabeta_)); |
1564 | this->vmovq(xnalphabeta_, this->imm_addr64_); |
1565 | this->vbroadcastss(vnalphabeta_, xnalphabeta_); |
1566 | |
1567 | bool is_single = J.version == 3; |
1568 | bool is_first = J.version == -1 || J.version == -2; |
1569 | bool is_last = J.version == +1 || J.version == -2; |
1570 | |
1571 | if (is_first || is_single) { |
1572 | this->vxorps(xsrc_prev, xsrc_prev, xsrc_prev); |
1573 | this->vmovups(this->ptr[t + 0], xsrc_prev); |
1574 | } |
1575 | if (is_last || is_single) { |
1576 | this->vxorps(xsrc_next, xsrc_next, xsrc_next); |
1577 | this->vmovups(this->ptr[t + 48], xsrc_next); |
1578 | } |
1579 | this->mov(hw, this->use_h_parallelizm_ ? J.W : J.H * J.W); |
1580 | Label lrn_loop; |
1581 | this->L(lrn_loop); |
1582 | { |
1583 | if (!is_first && !is_single) { |
1584 | this->vmovups(xws_prev, this->ptr[scratch_ - J.H * J.W * 32 + 16]); |
1585 | this->vmovups(xsrc_prev, this->ptr[src_ - J.H * J.W * 32 + 16]); |
1586 | this->vmovups( |
1587 | xdiffdst_prev, this->ptr[diffdst_ - J.H * J.W * 32 + 16]); |
1588 | this->vmulps(xa, xws_prev, xws_prev); |
1589 | this->vmulps(xa, xa, xws_prev); |
1590 | this->vsqrtps(xa, xa); |
1591 | this->vsqrtps(xa, xa); |
1592 | this->vmulps(xa, xa, xws_prev); |
1593 | this->vdivps(xsrc_prev, xsrc_prev, xa); |
1594 | this->vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev); |
1595 | } |
1596 | |
1597 | this->vmovups(ysrc, this->ptr[src_]); |
1598 | this->vmovups(yws, this->ptr[scratch_]); |
1599 | this->vmovups(ydiffdst, this->ptr[diffdst_]); |
1600 | this->vmulps(ya, yws, yws); |
1601 | this->vmulps(ya, ya, yws); |
1602 | this->vsqrtps(ya, ya); |
1603 | this->vsqrtps(ya, ya); |
1604 | this->vdivps(ydiffsrc, ydiffdst, ya); |
1605 | this->vdivps(ysum, ydiffsrc, yws); |
1606 | this->vmulps(ysum, ysum, ysrc); |
1607 | |
1608 | if (!is_last && !is_single) { |
1609 | this->vmovups(xws_next, this->ptr[scratch_ + J.H * J.W * 32]); |
1610 | this->vmovups(xsrc_next, this->ptr[src_ + J.H * J.W * 32]); |
1611 | this->vmovups(xdiffdst_next, this->ptr[diffdst_ + J.H * J.W * 32]); |
1612 | this->vmulps(xa, xws_next, xws_next); |
1613 | this->vmulps(xa, xa, xws_next); |
1614 | this->vsqrtps(xa, xa); |
1615 | this->vsqrtps(xa, xa); |
1616 | this->vmulps(xa, xa, xws_next); |
1617 | this->vdivps(xsrc_next, xsrc_next, xa); |
1618 | this->vmulps(xdiffdst_next, xdiffdst_next, xsrc_next); |
1619 | } |
1620 | |
1621 | if (!is_first && !is_single) |
1622 | this->vmovups(this->ptr[t + 0], xdiffdst_prev); |
1623 | this->vmovups(this->ptr[t + 16], ysum); |
1624 | if (!is_last && !is_single) |
1625 | this->vmovups(this->ptr[t + 48], xdiffdst_next); |
1626 | |
1627 | this->vmovups(ya, this->ptr[t + 16 - 8]); |
1628 | this->vmovups(yb, this->ptr[t + 16 - 4]); |
1629 | this->vaddps(ysum, ysum, ya); |
1630 | this->vmulps(ysrc, ysrc, vnalphabeta_); |
1631 | this->vaddps(ysum, ysum, yb); |
1632 | |
1633 | this->vmovups(yd, this->ptr[t + 16 + 4]); |
1634 | this->vmovups(ye, this->ptr[t + 16 + 8]); |
1635 | this->vaddps(ysum, ysum, yd); |
1636 | this->vaddps(ysum, ysum, ye); |
1637 | |
1638 | this->vfmadd231ps(ydiffsrc, ysum, ysrc); |
1639 | |
1640 | this->vmovups(this->ptr[diffsrc_], ydiffsrc); |
1641 | |
1642 | this->add(src_, 32); |
1643 | this->add(diffsrc_, 32); |
1644 | this->add(diffdst_, 32); |
1645 | this->add(scratch_, 32); |
1646 | |
1647 | this->dec(hw); |
1648 | this->cmp(hw, 0); |
1649 | this->jne(lrn_loop, this->T_NEAR); |
1650 | } |
1651 | |
1652 | this->add(t, 64); |
1653 | this->postamble(); |
1654 | } |
1655 | |
1656 | template <cpu_isa_t isa, data_type_t d_type> |
1657 | jit_uni_lrn_bwd_kernel_t<isa, d_type>::jit_uni_lrn_bwd_kernel_t( |
1658 | const within_config_t &config, float A, float B, void *code_ptr, |
1659 | size_t code_size) |
1660 | : Base(config, code_ptr, code_size, jit_name()) |
1661 | , config_(lrn_config_t::within_config) |
1662 | , within_config_(config) |
1663 | , nalphabeta_(-2.0f * A * B) {} |
1664 | |
1665 | template <cpu_isa_t isa, data_type_t d_type> |
1666 | void jit_uni_lrn_bwd_kernel_t<isa, d_type>::generate( |
1667 | const within_config_t &config) { |
1668 | |
1669 | this->preamble(); |
1670 | if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16(); |
1671 | |
1672 | #define GET_OFF(field) offsetof(jit_args_bwd_t, field) |
1673 | this->mov(src_, this->ptr[this->param1 + GET_OFF(src)]); |
1674 | this->mov(diffdst_, this->ptr[this->param1 + GET_OFF(diff_dst)]); |
1675 | this->mov(scratch_, this->ptr[this->param1 + GET_OFF(scratch)]); |
1676 | this->mov(bwd_intermediate_res_, |
1677 | this->ptr[this->param1 + GET_OFF(bwd_intermediate_res)]); |
1678 | this->mov(diffsrc_, this->ptr[this->param1 + GET_OFF(diff_src)]); |
1679 | #undef GET_OFF |
1680 | this->load_constant(nalphabeta_, vnalphabeta_, xnalphabeta_); |
1681 | |
1682 | static const int max_reg_blocks = is_superset(isa, avx512_core) ? 3 : 1; |
1683 | this->within_loop(config, max_reg_blocks, prop_kind::backward); |
1684 | |
1685 | this->postamble(); |
1686 | } |
1687 | |
1688 | template <cpu_isa_t isa, data_type_t d_type> |
1689 | void jit_uni_lrn_bwd_kernel_t<isa, d_type>::within_body(int hoff, int Hoff, |
1690 | int woff, int Woff, int stride, prop_kind_t pk, const int reg_block, |
1691 | int pixel_offset) { |
1692 | |
1693 | static const std::array<Vmm, 3> vsum {{Vmm(1), Vmm(9), Vmm(18)}}; |
1694 | static const std::array<std::array<Vmm, 3>, 3> diff_dst {{ |
1695 | {{Vmm(2), Vmm(3), Vmm(6)}}, |
1696 | {{Vmm(10), Vmm(11), Vmm(23)}}, |
1697 | {{Vmm(19), Vmm(20), Vmm(26)}}, |
1698 | }}; |
1699 | static const std::array<std::array<Vmm, 3>, 3> ws1 {{ |
1700 | {{Vmm(4), Vmm(5), Vmm(15)}}, |
1701 | {{Vmm(12), Vmm(13), Vmm(27)}}, |
1702 | {{Vmm(21), Vmm(22), Vmm(28)}}, |
1703 | }}; |
1704 | static const std::array<Vmm, 3> ws0 = !this->emulate_bfloat_ |
1705 | ? std::array<Vmm, 3> {{Vmm(29), Vmm(30), Vmm(31)}} |
1706 | : std::array<Vmm, 3> {{Vmm(6), Vmm(15), Vmm(23)}}; |
1707 | static const std::array<Vmm, 3> src {{Vmm(7), Vmm(16), Vmm(24)}}; |
1708 | static const std::array<Vmm, 3> a {{Vmm(8), Vmm(17), Vmm(25)}}; |
1709 | |
1710 | static const std::size_t used_tmp_regs |
1711 | = this->emulate_bfloat_ ? ws1[0].size() - 1 : ws1[0].size(); |
1712 | |
1713 | IRB_LOOP(this->uni_vxorps(vsum[irb], vsum[irb], vsum[irb])); |
1714 | for (int i = hoff; i <= Hoff; ++i) { |
1715 | for (int j = woff; j <= Woff; ++j) { |
1716 | const auto idx = this->tempIdx_ % used_tmp_regs; |
1717 | IRB_LOOP(this->load_data(diff_dst[irb][idx], |
1718 | this->ptr[(diffdst_ + pixel_offset + irb_off) |
1719 | + (i * stride + j) * this->single_pixel_offset_])); |
1720 | IRB_LOOP(this->load_data(ws1[irb][idx], |
1721 | this->ptr[(bwd_intermediate_res_ + pixel_offset + irb_off) |
1722 | + (i * stride + j) * this->single_pixel_offset_])); |
1723 | |
1724 | if (i == 0 && j == 0) { |
1725 | if (utils::one_of(d_type, data_type::bf16, data_type::f16)) { |
1726 | IRB_LOOP(this->load_data(ws0[irb], |
1727 | this->ptr[(scratch_ + pixel_offset + irb_off)])); |
1728 | IRB_LOOP( |
1729 | this->vdivps(a[irb], diff_dst[irb][idx], ws0[irb])); |
1730 | } else { |
1731 | IRB_LOOP(this->vdivps(a[irb], diff_dst[irb][idx], |
1732 | this->ptr[(scratch_ + pixel_offset + irb_off)])); |
1733 | } |
1734 | } |
1735 | |
1736 | IRB_LOOP(this->vfmadd231ps( |
1737 | vsum[irb], ws1[irb][idx], diff_dst[irb][idx])); |
1738 | ++(this->tempIdx_); |
1739 | } |
1740 | } |
1741 | |
1742 | this->tempIdx_ = this->tempIdx_ % used_tmp_regs; |
1743 | |
1744 | if (utils::one_of(d_type, data_type::bf16, data_type::f16)) { |
1745 | IRB_LOOP(this->load_data( |
1746 | src[irb], this->ptr[(src_ + pixel_offset + irb_off)])); |
1747 | IRB_LOOP(this->vmulps(src[irb], this->vnalphabeta_, src[irb])); |
1748 | } else { |
1749 | IRB_LOOP(this->vmulps(src[irb], this->vnalphabeta_, |
1750 | this->ptr[(src_ + pixel_offset + irb_off)])); |
1751 | } |
1752 | |
1753 | IRB_LOOP(this->vfmadd231ps(a[irb], src[irb], vsum[irb])); |
1754 | |
1755 | IRB_LOOP(this->store_data( |
1756 | this->ptr[diffsrc_ + pixel_offset + irb_off], a[irb])); |
1757 | |
1758 | if (is_superset(isa, avx512_core)) |
1759 | this->reg_block_idx_ = (this->reg_block_idx_ % vsum.size()) + 1; |
1760 | } |
1761 | |
1762 | template <cpu_isa_t isa, data_type_t d_type> |
1763 | void jit_uni_lrn_bwd_kernel_t<isa, d_type>::move_data_pointers( |
1764 | int pixel_count, prop_kind_t pk) { |
1765 | const int pixel_offset = this->single_pixel_offset_ * pixel_count; |
1766 | this->add(src_, pixel_offset); |
1767 | this->add(diffsrc_, pixel_offset); |
1768 | this->add(diffdst_, pixel_offset); |
1769 | this->add(scratch_, pixel_offset); |
1770 | this->add(bwd_intermediate_res_, pixel_offset); |
1771 | } |
1772 | |
1773 | template class jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>; |
1774 | template class jit_uni_lrn_fwd_kernel_t<avx2, dnnl::impl::data_type::f32>; |
1775 | template class jit_uni_lrn_fwd_kernel_t<avx512_core, |
1776 | dnnl::impl::data_type::f32>; |
1777 | template class jit_uni_lrn_fwd_kernel_t<avx512_core, |
1778 | dnnl::impl::data_type::bf16>; |
1779 | template class jit_uni_lrn_fwd_kernel_t<avx512_core_fp16, |
1780 | dnnl::impl::data_type::f16>; |
1781 | |
1782 | template class jit_uni_lrn_kernel_t< |
1783 | jit_uni_lrn_fwd_kernel_t<sse41, dnnl::impl::data_type::f32>>; |
1784 | template class jit_uni_lrn_kernel_t< |
1785 | jit_uni_lrn_fwd_kernel_t<avx2, dnnl::impl::data_type::f32>>; |
1786 | template class jit_uni_lrn_kernel_t< |
1787 | jit_uni_lrn_fwd_kernel_t<avx512_core, dnnl::impl::data_type::f32>>; |
1788 | template class jit_uni_lrn_kernel_t< |
1789 | jit_uni_lrn_fwd_kernel_t<avx512_core, dnnl::impl::data_type::bf16>>; |
1790 | template class jit_uni_lrn_kernel_t< |
1791 | jit_uni_lrn_fwd_kernel_t<avx512_core_fp16, dnnl::impl::data_type::f16>>; |
1792 | |
1793 | template class jit_uni_lrn_bwd_kernel_t<avx512_core_fp16, |
1794 | dnnl::impl::data_type::f16>; |
1795 | template class jit_uni_lrn_bwd_kernel_t<avx512_core, |
1796 | dnnl::impl::data_type::f32>; |
1797 | template class jit_uni_lrn_bwd_kernel_t<avx512_core, |
1798 | dnnl::impl::data_type::bf16>; |
1799 | template class jit_uni_lrn_bwd_kernel_t<avx2, dnnl::impl::data_type::f32>; |
1800 | |
1801 | template class jit_uni_lrn_kernel_t< |
1802 | jit_uni_lrn_bwd_kernel_t<avx2, dnnl::impl::data_type::f32>>; |
1803 | template class jit_uni_lrn_kernel_t< |
1804 | jit_uni_lrn_bwd_kernel_t<avx512_core, dnnl::impl::data_type::f32>>; |
1805 | template class jit_uni_lrn_kernel_t< |
1806 | jit_uni_lrn_bwd_kernel_t<avx512_core, dnnl::impl::data_type::bf16>>; |
1807 | template class jit_uni_lrn_kernel_t< |
1808 | jit_uni_lrn_bwd_kernel_t<avx512_core_fp16, dnnl::impl::data_type::f16>>; |
1809 | |
1810 | } // namespace x64 |
1811 | } // namespace cpu |
1812 | } // namespace impl |
1813 | } // namespace dnnl |
1814 | |
1815 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
1816 | |