/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cstdlib>
#include <numeric>

#include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_nhwc.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace lrn {

using acc_data_t = float;

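// NHWC backward LRN kernel. tmp_mask_prev_ / tmp_mask_next_ hold the zmm
// register indices reserved for the permutation masks used to shift data
// towards the previous / next channels; half_ls_ is the number of neighbours
// on each side of the local-size window.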
template <data_type_t d_type>
jit_avx512_common_lrn_kernel_bwd_nhwc_t<
        d_type>::jit_avx512_common_lrn_kernel_bwd_nhwc_t(unsigned C,
        float alpha, float beta, int local_size, void *code_ptr,
        size_t code_size)
    : jit_avx512_common_lrn_kernel_bwd_t<d_type>(
            alpha, beta, local_size, code_ptr, code_size, jit_name())
    , tmp_mask_prev_ {[this]() {
        std::vector<int> v(this->local_size_ / 2);
        std::iota(v.begin(), v.end(), this->zdiffsrc_ + 2);
        return v;
    }()}
    , tmp_mask_next_ {[this]() {
        std::vector<int> v(this->local_size_ / 2);
        std::iota(v.begin(), v.end(),
                this->zdiffsrc_ + 2 + this->local_size_ / 2);
        return v;
    }()}
    , half_ls_ {(local_size - 1) / 2}
    , C_(C) {}

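// Kernel entry point: split C into full 16-channel blocks plus a remainder
// (C_tail). When a tail exists, a zero-initialized stack scratch area is
// reserved so partially filled vectors can be processed as full ones.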
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::generate() {

    const auto res = std::div(C_, 16);
    const auto &C_tail = res.rem;
    const auto &num_full_16c_blocks = res.quot;
    static const auto stack_space = zmm_size_ * 9;

    this->preamble();
    if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();
    if (C_tail) reserve_stack_space(stack_space);
    this->set_up_ker_params();
    this->execute_compute_loop(num_full_16c_blocks, C_tail);
    if (C_tail) unreserve_stack_space(stack_space);

    this->postamble();
}

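// Reserve the stack scratch area and zero-fill all but its last zmm-sized
// slot, so that tail loads staged into these slots read zeros past C_tail.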
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::reserve_stack_space(
        std::size_t space) {
    const unsigned maxCounter = (space / zmm_size_) - 1;
    this->sub(rsp, space);
    this->uni_vpxor(zmm4, zmm4, zmm4);
    for (unsigned i = 0; i < maxCounter; ++i)
        this->vmovups(ptr[rsp + i * zmm_size_], zmm4);
}

template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::unreserve_stack_space(
        std::size_t space) {
    this->add(rsp, space);
}

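// Map a data pointer register to the byte offset of its zmm-sized slot in the
// stack scratch area; tail_mode::NextTail addresses the slot directly below,
// which is used for the neighbouring full chunk staged next to the tail.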
template <data_type_t d_type>
int jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::get_stack_offset(
        const Reg64 reg, tail_mode tail_proc) {

    int stack_position = 0;
    if (reg == this->diffdst_)
        stack_position = 1;
    else if (reg == this->workspace1_)
        stack_position = 3;
    else if (reg == this->workspace0_)
        stack_position = 4;
    else if (reg == this->src_)
        stack_position = 5;

    return zmm_size_
            * (stack_position + (tail_proc == tail_mode::NextTail ? -1 : 0));
}

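// Stage tail data on the stack: copy the C_tail leftover elements of
// diff_dst, ws0, ws1 and src into their zero-padded stack slots and, unless
// this is the only chunk, also copy the adjacent full chunk of diff_dst and
// ws1 so the window can be read across the chunk boundary.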
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::load_data_to_stack(
        unsigned C_tail, across_version version, tail_mode tail_proc) {

    if (version != across_version::Single) {
        const int previousChunkOffset
                = tail_proc == tail_mode::NextTail ? 0 : -1 * this->vlen_;
        this->load_data(this->zreg(0, tmp_load_to_stack_idx_prev_),
                this->EVEX_compress_addr(this->diffdst_, previousChunkOffset));
        this->vmovups(
                this->EVEX_compress_addr(rsp,
                        get_stack_offset(this->diffdst_, tail_mode::NextTail)),
                this->zreg(0, tmp_load_to_stack_idx_prev_));

        this->load_data(this->zreg(0, tmp_load_to_stack_idx_prev_),
                this->EVEX_compress_addr(
                        this->workspace1_, previousChunkOffset));
        this->vmovups(this->EVEX_compress_addr(rsp,
                              get_stack_offset(
                                      this->workspace1_, tail_mode::NextTail)),
                this->zreg(0, tmp_load_to_stack_idx_prev_));
    }

    const int tail_src_mem_offset
            = tail_proc == tail_mode::NextTail ? this->vlen_ : 0;
    this->load_tail(C_tail, this->diffdst_, tail_src_mem_offset,
            get_stack_offset(this->diffdst_, tail_mode::CurrentTail),
            this->tmp_load_to_stack_idx_tail_);
    this->load_tail(C_tail, this->workspace0_, tail_src_mem_offset,
            get_stack_offset(this->workspace0_, tail_mode::CurrentTail),
            this->tmp_load_to_stack_idx_tail_);
    this->load_tail(C_tail, this->workspace1_, tail_src_mem_offset,
            get_stack_offset(this->workspace1_, tail_mode::CurrentTail),
            this->tmp_load_to_stack_idx_tail_);
    this->load_tail(C_tail, this->src_, tail_src_mem_offset,
            get_stack_offset(this->src_, tail_mode::CurrentTail),
            this->tmp_load_to_stack_idx_tail_);
}

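// Load the kernel arguments (data pointers and the permutation-mask pointer)
// from the jit_args_bwd_t structure and broadcast the nalphabeta_ scalar into
// znalphabeta_.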
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::set_up_ker_params() {
#define GET_OFF(field) \
    offsetof(typename jit_avx512_common_lrn_kernel_bwd_t< \
                     d_type>::jit_args_bwd_t, \
            field)
    this->mov(this->src_, ptr[this->param_ + GET_OFF(src)]);
    this->mov(this->diffdst_, ptr[this->param_ + GET_OFF(diff_dst)]);
    this->mov(this->workspace0_, ptr[this->param_ + GET_OFF(ws0)]);
    this->mov(this->workspace1_, ptr[this->param_ + GET_OFF(ws1)]);
    this->mov(this->diffsrc_, ptr[this->param_ + GET_OFF(diff_src)]);

    this->mov(this->mask_, ptr[this->param_ + GET_OFF(mask_ptr)]);
#undef GET_OFF

    this->mov(this->imm_addr64_, float2int(this->nalphabeta_));
    this->vmovq(this->xnalphabeta_, this->imm_addr64_);
    this->vbroadcastss(this->znalphabeta_, this->xnalphabeta_);
}

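// Drive the channel loop: a single chunk is handled with
// across_version::Single, otherwise the first and last 16-channel chunks get
// dedicated First/Last versions, the bulk is unrolled by reg_block_ chunks per
// iteration, and the chunk just before the tail (LTAIL) is emitted separately
// so its next-neighbour loads can read across into the zero-padded stack copy
// of the tail.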
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::execute_compute_loop(
        unsigned num_full_16c_blocks, unsigned C_tail) {

    if ((num_full_16c_blocks == 1u && !C_tail)
            || (num_full_16c_blocks == 0u && C_tail)) {
        const auto tail_proc
                = C_tail ? tail_mode::CurrentTail : tail_mode::NoTail;
        compute_loop(across_version::Single, tail_proc, C_tail);
    } else {
        const int begin_end = C_tail ? 1 : 2;
        int middle_16_c_blocks = num_full_16c_blocks == 1
                ? 0
                : num_full_16c_blocks - begin_end;
        int LTAIL = 0;
        if (C_tail && middle_16_c_blocks) {
            middle_16_c_blocks -= 1;
            LTAIL = 1;
        }

        const int LSREST = middle_16_c_blocks % this->reg_block_;
        const int LS = middle_16_c_blocks - LSREST;

        if (LS > 0) this->mov(this->blockC_, LS);
        const auto first_tail_proc = num_full_16c_blocks == 1
                ? tail_mode::NextTail
                : tail_mode::NoTail;
        compute_loop(across_version::First, first_tail_proc, C_tail);
        increment_loop_params(this->vlen_);

        Label lrn_loop;

        if (LS > 0) {

            this->L(lrn_loop);
            {
                compute_loop(across_version::Middle, tail_mode::NoTail, C_tail,
                        this->reg_block_);
                increment_loop_params(this->reg_block_ * this->vlen_);
                this->sub(this->blockC_, this->reg_block_);
                this->cmp(this->blockC_, 0);
                this->jne(lrn_loop, this->T_NEAR);
            }
        }

        if (LSREST > 0) {
            compute_loop(
                    across_version::Middle, tail_mode::NoTail, C_tail, LSREST);
            increment_loop_params(LSREST * this->vlen_);
        }

        if (LTAIL) {
            compute_loop(
                    across_version::Middle, tail_mode::NextTail, C_tail, LTAIL);
            increment_loop_params(LTAIL * this->vlen_);
        }

        const auto last_tail_proc
                = C_tail ? tail_mode::CurrentTail : tail_mode::NoTail;
        compute_loop(across_version::Last, last_tail_proc, C_tail);
    }
}

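// One pass over loop_size_param 16-channel chunks: stage tail data on the
// stack if needed, load the inputs, compute diff_src and store it.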
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::compute_loop(
        across_version version, tail_mode tail_proc, unsigned C_tail,
        int loop_size_param) {

    if (tail_proc != tail_mode::NoTail)
        load_data_to_stack(C_tail, version, tail_proc);
    load_compute_data(version, tail_proc, loop_size_param);
    compute(loop_size_param, tail_proc);
    store_compute_data(loop_size_param, tail_proc, C_tail);
}

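// Finish the reduction and apply the backward formula: the window sum of
// (diff_dst * ws1) terms is accumulated into zdiffsrc_, src is scaled by
// znalphabeta_, diff_dst is divided by ws0, and the final FMA produces
// diff_src = src * nalphabeta * sum + diff_dst / ws0.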
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::compute(
        int loop_size, tail_mode tail_proc) {

    IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->z_prev_[0])));
    assert(this->zsrc_ == this->z_prev_[0]);

    if (tail_proc == tail_mode::CurrentTail)
        this->load_data(this->zreg(0, this->zsrc_),
                this->EVEX_compress_addr(rsp,
                        get_stack_offset(this->src_, tail_mode::CurrentTail)),
                true);
    else
        IRB_LOOP(this->load_data(this->zreg(irb, this->zsrc_),
                this->EVEX_compress_addr(this->src_, irb * this->vlen_)));

    for (unsigned regIdx = 1; regIdx < this->z_prev_.size(); ++regIdx)
        IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
                this->zreg(irb, this->zdiffsrc_),
                this->zreg(irb, this->z_prev_[regIdx])));
    for (const auto reg : this->z_next_)
        IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
                this->zreg(irb, this->zdiffsrc_), this->zreg(irb, reg)));

    IRB_LOOP(this->vmulps(this->zreg(irb, this->zsrc_),
            this->zreg(irb, this->zsrc_), this->znalphabeta_));

    if (tail_proc == tail_mode::CurrentTail) {
        this->load_data(this->zreg(0, this->zws0_),
                this->EVEX_compress_addr(rsp,
                        get_stack_offset(
                                this->workspace0_, tail_mode::CurrentTail)),
                true);
    } else {
        IRB_LOOP(this->load_data(this->zreg(irb, this->zws0_),
                this->EVEX_compress_addr(
                        this->workspace0_, irb * this->vlen_)));
    }

    IRB_LOOP(this->vdivps(this->zreg(irb, this->zdiffdst_),
            this->zreg(irb, this->zdiffdst_), this->zreg(irb, this->zws0_)));
    IRB_LOOP(this->vfmadd213ps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zsrc_), this->zreg(irb, this->zdiffdst_)));
}

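// Advance all data pointers by `offset` bytes to the next chunk of channels.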
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::increment_loop_params(
        std::size_t offset) {
    this->add(this->src_, offset);
    this->add(this->diffsrc_, offset);
    this->add(this->diffdst_, offset);
    this->add(this->workspace0_, offset);
    this->add(this->workspace1_, offset);
}

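// Load the center product diff_dst * ws1 into zdiffsrc_ and fill the
// z_prev_ / z_next_ registers with the corresponding products of the
// neighbouring channels. At the edges of the tensor (First/Last/Single) the
// missing neighbours are synthesized by permuting the center product with a
// mask that shifts it and pads with zeros; for tails the data comes from the
// stack copies instead of memory.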
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::load_compute_data(
        across_version version, tail_mode tail_proc, int loop_size_param) {

    const int loop_size = loop_size_param;
    static constexpr int mask_shift = sizeof(int32_t);
    static constexpr int acc_size = utils::one_of(d_type, bf16, f16) ? 2 : 4;
    const auto load_shifted_padded_with_zeros
            = [this](int dstIdx, int srcIdx, int maskTmpIdx, int offset) {
                  this->uni_vpxor(this->zreg(0, dstIdx), this->zreg(0, dstIdx),
                          this->zreg(0, dstIdx));
                  this->load_data(this->zreg(0, maskTmpIdx),
                          this->EVEX_compress_addr(this->mask_, offset), true);
                  this->vpermt2ps(this->zreg(0, dstIdx),
                          this->zreg(0, maskTmpIdx), this->zreg(0, srcIdx));
              };

    const auto load_ws_diffdst = [&, this](int dstIdx, int offset,
                                         tail_mode tail_proc) {
        if (tail_proc == tail_mode::NoTail) {
            IRB_LOOP(this->load_data(this->zreg(irb, dstIdx),
                    this->EVEX_compress_addr(
                            this->workspace1_, (irb * this->vlen_) + offset)));
        } else
            this->load_data(this->zreg(0, dstIdx),
                    this->EVEX_compress_addr(this->rsp,
                            get_stack_offset(this->workspace1_, tail_proc)
                                    + offset),
                    true);

        if (utils::one_of(d_type, bf16, f16)
                || tail_proc != tail_mode::NoTail) {
            if (tail_proc == tail_mode::NoTail) {
                IRB_LOOP(this->load_data(this->zreg(irb, this->z_tmp_),
                        this->EVEX_compress_addr(
                                this->diffdst_, (irb * this->vlen_) + offset)));
            } else
                this->load_data(this->zreg(0, this->z_tmp_),
                        this->EVEX_compress_addr(this->rsp,
                                get_stack_offset(this->diffdst_, tail_proc)
                                        + offset),
                        true);

            IRB_LOOP(this->vmulps(this->zreg(irb, dstIdx),
                    this->zreg(irb, this->z_tmp_), this->zreg(irb, dstIdx)));
        } else {
            IRB_LOOP(this->vmulps(this->zreg(irb, dstIdx),
                    this->zreg(irb, dstIdx),
                    this->EVEX_compress_addr(
                            this->diffdst_, (irb * this->vlen_) + offset)));
        }
    };

    if (tail_proc == tail_mode::CurrentTail) {
        this->load_data(this->zreg(0, this->zdiffsrc_),
                this->EVEX_compress_addr(
                        rsp, get_stack_offset(this->workspace1_, tail_proc)),
                true);
        this->load_data(this->zreg(0, this->zdiffdst_),
                this->EVEX_compress_addr(
                        rsp, get_stack_offset(this->diffdst_, tail_proc)),
                true);
    } else {
        IRB_LOOP(this->load_data(this->zreg(irb, this->zdiffsrc_),
                this->EVEX_compress_addr(
                        this->workspace1_, irb * this->vlen_)));
        IRB_LOOP(this->load_data(this->zreg(irb, this->zdiffdst_),
                this->EVEX_compress_addr(this->diffdst_, irb * this->vlen_)));
    }

    IRB_LOOP(this->vmulps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffdst_),
            this->zreg(irb, this->zdiffsrc_)));

    int reg, mask, pos;
    std::vector<std::tuple<int, int, int>> prev_v;
    prev_v.reserve(this->half_ls_);
    for (int pos = 0; pos < this->half_ls_; ++pos) {
        prev_v.emplace_back(this->z_prev_[pos], this->tmp_mask_prev_[pos],
                this->half_ls_ - pos);
    }
    if (version == across_version::First
            || version == across_version::Single) {
        for (const auto &reg_mask_pos : prev_v) {
            std::tie(reg, mask, pos) = reg_mask_pos;
            load_shifted_padded_with_zeros(
                    reg, this->zdiffsrc_, mask, -1 * pos * mask_shift);
        }
    } else {
        for (const auto &reg_mask_pos : prev_v) {
            std::tie(reg, mask, pos) = reg_mask_pos;
            IRB_LOOP(load_ws_diffdst(reg, -1 * pos * acc_size,
                    tail_proc == tail_mode::CurrentTail ? tail_mode::CurrentTail
                                                        : tail_mode::NoTail));
        }
    }

    std::vector<std::tuple<int, int, int>> next_v;
    next_v.reserve(this->half_ls_);
    for (int pos = 0; pos < this->half_ls_; ++pos) {
        next_v.emplace_back(
                this->z_next_[pos], this->tmp_mask_next_[pos], pos + 1);
    }
    if (version == across_version::Last
            || version == across_version::Single) {
        for (const auto &reg_mask_pos : next_v) {
            std::tie(reg, mask, pos) = reg_mask_pos;
            load_shifted_padded_with_zeros(
                    reg, this->zdiffsrc_, mask, pos * mask_shift);
        }
    } else {
        for (const auto &reg_mask_pos : next_v) {
            std::tie(reg, mask, pos) = reg_mask_pos;
            IRB_LOOP(load_ws_diffdst(reg, pos * acc_size,
                    tail_proc == tail_mode::NextTail ? tail_mode::NextTail
                                                     : tail_mode::NoTail));
        }
    }
}

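// Write diff_src back: tails go through store_tail using the stack scratch,
// while full chunks test diff_src for vlen_ alignment and pick the aligned
// store_data variant, falling back to the unaligned one otherwise.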
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::store_compute_data(
        int loop_size_param, tail_mode tail_proc, unsigned C_tail) {
    const int loop_size = loop_size_param;

    if (tail_proc == tail_mode::CurrentTail) {
        this->store_tail(C_tail, this->zreg(0, this->zdiffsrc_), this->diffsrc_,
                0, 8 * zmm_size_, tmp_store_from_stack_idx_tail_);
    } else {
        Label unaligned_store, end_store;
        this->test(this->diffsrc_, this->vlen_ - 1);
        this->jnz(unaligned_store, this->T_NEAR);
        IRB_LOOP(this->store_data(true,
                this->EVEX_compress_addr(this->diffsrc_, irb * this->vlen_),
                this->zreg(irb, this->zdiffsrc_)));
        this->jmp(end_store, this->T_NEAR);
        this->L(unaligned_store);
        {
            IRB_LOOP(this->store_data(false,
                    this->EVEX_compress_addr(this->diffsrc_, irb * this->vlen_),
                    this->zreg(irb, this->zdiffsrc_)));
        }
        this->L(end_store);
    }
}

template class jit_avx512_common_lrn_kernel_bwd_nhwc_t<f32>;
template class jit_avx512_common_lrn_kernel_bwd_nhwc_t<bf16>;
template class jit_avx512_common_lrn_kernel_bwd_nhwc_t<f16>;

} // namespace lrn
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl