1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <numeric> |
18 | #include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_nhwc.hpp" |
19 | |
20 | namespace dnnl { |
21 | namespace impl { |
22 | namespace cpu { |
23 | namespace x64 { |
24 | namespace lrn { |
25 | |
26 | using acc_data_t = float; |
27 | |
28 | template <data_type_t d_type> |
29 | jit_avx512_common_lrn_kernel_bwd_nhwc_t< |
30 | d_type>::jit_avx512_common_lrn_kernel_bwd_nhwc_t(unsigned C, |
31 | float alpha, float beta, int local_size, void *code_ptr, |
32 | size_t code_size) |
33 | : jit_avx512_common_lrn_kernel_bwd_t<d_type>( |
34 | alpha, beta, local_size, code_ptr, code_size, jit_name()) |
35 | , tmp_mask_prev_ {[this]() { |
36 | std::vector<int> v(this->local_size_ / 2); |
37 | std::iota(v.begin(), v.end(), this->zdiffsrc_ + 2); |
38 | return v; |
39 | }()} |
40 | , tmp_mask_next_ {[this]() { |
41 | std::vector<int> v(this->local_size_ / 2); |
42 | std::iota(v.begin(), v.end(), |
43 | this->zdiffsrc_ + 2 + this->local_size_ / 2); |
44 | return v; |
45 | }()} |
46 | , half_ls_ {(local_size - 1) / 2} |
47 | , C_(C) {} |
48 | |
49 | template <data_type_t d_type> |
50 | void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::generate() { |
51 | |
52 | const auto res = std::div(C_, 16); |
53 | const auto &C_tail = res.rem; |
54 | const auto &num_full_16c_blocks = res.quot; |
55 | static const auto stack_space = zmm_size_ * 9; |
56 | |
57 | this->preamble(); |
58 | if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16(); |
59 | if (C_tail) reserve_stack_space(stack_space); |
60 | this->set_up_ker_params(); |
61 | this->execute_compute_loop(num_full_16c_blocks, C_tail); |
62 | if (C_tail) unreserve_stack_space(stack_space); |
63 | |
64 | this->postamble(); |
65 | } |
66 | |
67 | template <data_type_t d_type> |
68 | void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::reserve_stack_space( |
69 | std::size_t space) { |
70 | const unsigned maxCounter = (space / zmm_size_) - 1; |
71 | this->sub(rsp, space); |
72 | this->uni_vpxor(zmm4, zmm4, zmm4); |
73 | for (unsigned i = 0; i < maxCounter; ++i) |
74 | this->vmovups(ptr[rsp + i * zmm_size_], zmm4); |
75 | } |
76 | |
// Emits code that releases the scratch area previously claimed by
// reserve_stack_space(); `space` must match the reserved amount.
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::unreserve_stack_space(
        std::size_t space) {
    this->add(rsp, space);
}
82 | |
83 | template <data_type_t d_type> |
84 | int jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::get_stack_offset( |
85 | const Reg64 reg, tail_mode tail_proc) { |
86 | |
87 | int stack_postion = 0; |
88 | if (reg == this->diffdst_) |
89 | stack_postion = 1; |
90 | else if (reg == this->workspace1_) |
91 | stack_postion = 3; |
92 | else if (reg == this->workspace0_) |
93 | stack_postion = 4; |
94 | else if (reg == this->src_) |
95 | stack_postion = 5; |
96 | |
97 | return zmm_size_ |
98 | * (stack_postion + (tail_proc == tail_mode::NextTail ? -1 : 0)); |
99 | } |
100 | |
// Stages tail data on the stack: copies the tail elements (and, for
// non-Single versions, the adjacent full chunk) of diffdst/workspace/src
// into the scratch slots so later loads can read whole zmm vectors with
// zero padding (the scratch area was pre-zeroed by reserve_stack_space).
// NOTE(review): load_data/load_tail presumably also widen bf16/f16 to
// f32 — confirm against the base-kernel helpers.
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::load_data_to_stack(
        unsigned C_tail, across_version version, tail_mode tail_proc) {

    if (version != across_version::Single) {
        // For CurrentTail the neighboring chunk sits one vector before
        // the pointer; for NextTail the pointer already addresses it.
        const int previousChunkOffset
                = tail_proc == tail_mode::NextTail ? 0 : -1 * this->vlen_;
        this->load_data(this->zreg(0, tmp_load_to_stack_idx_prev_),
                this->EVEX_compress_addr(this->diffdst_, previousChunkOffset));
        // Spill into the slot below diffdst_'s own (NextTail offset).
        this->vmovups(
                this->EVEX_compress_addr(rsp,
                        get_stack_offset(this->diffdst_, tail_mode::NextTail)),
                this->zreg(0, tmp_load_to_stack_idx_prev_));

        this->load_data(this->zreg(0, tmp_load_to_stack_idx_prev_),
                this->EVEX_compress_addr(
                        this->workspace1_, previousChunkOffset));
        this->vmovups(this->EVEX_compress_addr(rsp,
                              get_stack_offset(
                                      this->workspace1_, tail_mode::NextTail)),
                this->zreg(0, tmp_load_to_stack_idx_prev_));
    }

    // The tail itself: for NextTail it lives one vector past the pointer.
    const int tail_src_mem_offset
            = tail_proc == tail_mode::NextTail ? this->vlen_ : 0;
    this->load_tail(C_tail, this->diffdst_, tail_src_mem_offset,
            get_stack_offset(this->diffdst_, tail_mode::CurrentTail),
            this->tmp_load_to_stack_idx_tail_);
    this->load_tail(C_tail, this->workspace0_, tail_src_mem_offset,
            get_stack_offset(this->workspace0_, tail_mode::CurrentTail),
            this->tmp_load_to_stack_idx_tail_);
    this->load_tail(C_tail, this->workspace1_, tail_src_mem_offset,
            get_stack_offset(this->workspace1_, tail_mode::CurrentTail),
            this->tmp_load_to_stack_idx_tail_);
    this->load_tail(C_tail, this->src_, tail_src_mem_offset,
            get_stack_offset(this->src_, tail_mode::CurrentTail),
            this->tmp_load_to_stack_idx_tail_);
}
139 | |
// Emits the kernel prologue that loads all runtime arguments from the
// jit_args_bwd_t struct (addressed via param_) into their dedicated
// registers, and broadcasts the nalphabeta_ scalar constant into
// znalphabeta_ for the vector arithmetic in compute().
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::set_up_ker_params() {
#define GET_OFF(field) \
    offsetof(typename jit_avx512_common_lrn_kernel_bwd_t< \
                     d_type>::jit_args_bwd_t, \
            field)
    this->mov(this->src_, ptr[this->param_ + GET_OFF(src)]);
    this->mov(this->diffdst_, ptr[this->param_ + GET_OFF(diff_dst)]);
    this->mov(this->workspace0_, ptr[this->param_ + GET_OFF(ws0)]);
    this->mov(this->workspace1_, ptr[this->param_ + GET_OFF(ws1)]);
    this->mov(this->diffsrc_, ptr[this->param_ + GET_OFF(diff_src)]);

    this->mov(this->mask_, ptr[this->param_ + GET_OFF(mask_ptr)]);
#undef GET_OFF

    // Materialize the float constant through a GPR, then broadcast it
    // across all lanes of znalphabeta_.
    this->mov(this->imm_addr64_, float2int(this->nalphabeta_));
    this->vmovq(this->xnalphabeta_, this->imm_addr64_);
    this->vbroadcastss(this->znalphabeta_, this->xnalphabeta_);
}
159 | |
// Plans and emits the channel-dimension traversal. A run of 16-channel
// chunks is processed as: First chunk (no left neighbor in memory),
// unrolled Middle chunks, optional Middle remainder, an optional chunk
// that precedes the tail (LTAIL, emitted with NextTail staging), and a
// Last chunk (no right neighbor). Degenerate inputs (a single chunk or
// tail-only) collapse to one Single-version chunk.
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::execute_compute_loop(
        unsigned num_full_16c_blocks, unsigned C_tail) {

    if ((num_full_16c_blocks == 1u && !C_tail)
            || (num_full_16c_blocks == 0u && C_tail)) {
        // Everything fits in one chunk: no neighbors on either side.
        const auto tail_proc
                = C_tail ? tail_mode::CurrentTail : tail_mode::NoTail;
        compute_loop(across_version::Single, tail_proc, C_tail);
    } else {
        // First/Last chunks are emitted outside the loop; when a tail
        // exists the Last position is the tail itself, so only one full
        // chunk (First) is carved off the middle count.
        const int begin_end = C_tail ? 1 : 2;
        int middle_16_c_blocks = num_full_16c_blocks == 1
                ? 0
                : num_full_16c_blocks - begin_end;
        // Reserve one middle chunk to run with NextTail staging just
        // before the tail chunk.
        int LTAIL = 0;
        if (C_tail && middle_16_c_blocks) {
            middle_16_c_blocks -= 1;
            LTAIL = 1;
        }

        // LS = unrolled-loop trip count (multiple of reg_block_),
        // LSREST = leftover middle chunks emitted straight-line.
        const int LSREST = middle_16_c_blocks % this->reg_block_;
        const int LS = middle_16_c_blocks - LSREST;

        if (LS > 0) this->mov(this->blockC_, LS);
        const auto first_tail_proc = num_full_16c_blocks == 1
                ? tail_mode::NextTail
                : tail_mode::NoTail;
        compute_loop(across_version::First, first_tail_proc, C_tail);
        increment_loop_params(this->vlen_);

        Label lrn_loop;

        if (LS > 0) {

            // Runtime loop over the unrolled middle chunks.
            this->L(lrn_loop);
            {
                compute_loop(across_version::Middle, tail_mode::NoTail, C_tail,
                        this->reg_block_);
                increment_loop_params(this->reg_block_ * this->vlen_);
                this->sub(this->blockC_, this->reg_block_);
                this->cmp(this->blockC_, 0);
                this->jne(lrn_loop, this->T_NEAR);
            }
        }

        if (LSREST > 0) {
            compute_loop(
                    across_version::Middle, tail_mode::NoTail, C_tail, LSREST);
            increment_loop_params(LSREST * this->vlen_);
        }

        if (LTAIL) {
            // Middle chunk adjacent to the tail: its right neighbor must
            // be read from the stack-staged copy (NextTail).
            compute_loop(
                    across_version::Middle, tail_mode::NextTail, C_tail, LTAIL);
            increment_loop_params(LTAIL * this->vlen_);
        }

        const auto last_tail_proc
                = C_tail ? tail_mode::CurrentTail : tail_mode::NoTail;
        compute_loop(across_version::Last, last_tail_proc, C_tail);
    }
}
222 | |
// Emits the code for one chunk: optionally stages tail data on the
// stack, then loads inputs, runs the arithmetic, and stores the result.
// `loop_size_param` is the number of register blocks processed at once.
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::compute_loop(
        across_version version, tail_mode tail_proc, unsigned C_tail,
        int loop_size_param) {

    // Tail chunks are first copied to the zero-padded stack scratch so
    // subsequent loads can always read full zmm vectors.
    if (tail_proc != tail_mode::NoTail)
        load_data_to_stack(C_tail, version, tail_proc);
    load_compute_data(version, tail_proc, loop_size_param);
    compute(loop_size_param, tail_proc);
    store_compute_data(loop_size_param, tail_proc, C_tail);
}
234 | |
// Emits the arithmetic core. At entry zdiffsrc_ holds the center
// ws1*diffdst product and z_prev_/z_next_ hold the neighbor products
// (prepared by load_compute_data). The sequence computes:
//   zdiffsrc = sum of center + all neighbor products
//   zdiffsrc = (src * nalphabeta) * zdiffsrc + diffdst / ws0   (fma)
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::compute(
        int loop_size, tail_mode tail_proc) {

    // Accumulate z_prev_[0] first: that register doubles as zsrc_ (see
    // the assert below), so consuming it now frees it for the src load.
    IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->z_prev_[0])));
    assert(this->zsrc_ == this->z_prev_[0]);

    // Load src: from the stack-staged tail copy, or straight from memory.
    if (tail_proc == tail_mode::CurrentTail)
        this->load_data(this->zreg(0, this->zsrc_),
                this->EVEX_compress_addr(rsp,
                        get_stack_offset(this->src_, tail_mode::CurrentTail)),
                true);
    else
        IRB_LOOP(this->load_data(this->zreg(irb, this->zsrc_),
                this->EVEX_compress_addr(this->src_, irb * this->vlen_)));

    // Accumulate the remaining neighbor products.
    for (unsigned regIdx = 1; regIdx < this->z_prev_.size(); ++regIdx)
        IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
                this->zreg(irb, this->zdiffsrc_),
                this->zreg(irb, this->z_prev_[regIdx])));
    for (const auto reg : this->z_next_)
        IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
                this->zreg(irb, this->zdiffsrc_), this->zreg(irb, reg)));

    // zsrc *= nalphabeta (broadcast constant set up in set_up_ker_params).
    IRB_LOOP(this->vmulps(this->zreg(irb, this->zsrc_),
            this->zreg(irb, this->zsrc_), this->znalphabeta_));

    // Load ws0 the same way src was loaded (stack tail copy vs memory).
    if (tail_proc == tail_mode::CurrentTail) {
        this->load_data(this->zreg(0, this->zws0_),
                this->EVEX_compress_addr(rsp,
                        get_stack_offset(
                                this->workspace0_, tail_mode::CurrentTail)),
                true);
    } else {
        IRB_LOOP(this->load_data(this->zreg(irb, this->zws0_),
                this->EVEX_compress_addr(
                        this->workspace0_, irb * this->vlen_)));
    }

    // diffdst /= ws0, then fuse: zdiffsrc = zsrc * zdiffsrc + diffdst.
    IRB_LOOP(this->vdivps(this->zreg(irb, this->zdiffdst_),
            this->zreg(irb, this->zdiffdst_), this->zreg(irb, this->zws0_)));
    IRB_LOOP(this->vfmadd213ps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zsrc_), this->zreg(irb, this->zdiffdst_)));
}
281 | |
282 | template <data_type_t d_type> |
283 | void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::increment_loop_params( |
284 | std::size_t offset) { |
285 | this->add(this->src_, offset); |
286 | this->add(this->diffsrc_, offset); |
287 | this->add(this->diffdst_, offset); |
288 | this->add(this->workspace0_, offset); |
289 | this->add(this->workspace1_, offset); |
290 | } |
291 | |
// Emits the loads feeding compute(): the center ws1*diffdst product into
// zdiffsrc_ and the per-neighbor products into z_prev_/z_next_. Boundary
// chunks (First/Single on the left, Last/Single on the right) have no
// in-memory neighbor on that side, so the neighbor vectors are built by
// permuting the center vector with a zero-padding mask table read from
// mask_. NOTE(review): IRB_LOOP presumably iterates `irb` over
// `loop_size` — confirm against the macro definition in the header.
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::load_compute_data(
        across_version version, tail_mode tail_proc, int loop_size_param) {

    // Consumed implicitly by the IRB_LOOP macro below.
    const int loop_size = loop_size_param;
    // Mask table entries are int32; element stride in memory depends on
    // the data type size (2 bytes for bf16/f16, 4 for f32).
    static constexpr int mask_shift = sizeof(int32_t);
    static constexpr int acc_size = utils::one_of(d_type, bf16, f16) ? 2 : 4;
    // Builds a shifted copy of zreg(srcIdx) with zeros permuted in at the
    // chunk boundary, using a mask vector loaded from mask_ + offset.
    const auto load_shifted_padded_with_zeros
            = [this](int dstIdx, int srcIdx, int maskTmpIdx, int offset) {
                  this->uni_vpxor(this->zreg(0, dstIdx), this->zreg(0, dstIdx),
                          this->zreg(0, dstIdx));
                  this->load_data(this->zreg(0, maskTmpIdx),
                          this->EVEX_compress_addr(this->mask_, offset), true);
                  this->vpermt2ps(this->zreg(0, dstIdx),
                          this->zreg(0, maskTmpIdx), this->zreg(0, srcIdx));
              };

    // Loads ws1 (and diffdst) at `offset` from the current position —
    // from memory for NoTail, otherwise from the stack-staged copy —
    // and leaves ws1*diffdst in zreg(dstIdx). The bf16/f16 path cannot
    // use a memory operand in vmulps, so it goes through z_tmp_.
    const auto load_ws_diffdst = [&, this](int dstIdx, int offset,
                                         tail_mode tail_proc) {
        if (tail_proc == tail_mode::NoTail) {
            IRB_LOOP(this->load_data(this->zreg(irb, dstIdx),
                    this->EVEX_compress_addr(
                            this->workspace1_, (irb * this->vlen_) + offset)));
        } else
            this->load_data(this->zreg(0, dstIdx),
                    this->EVEX_compress_addr(this->rsp,
                            get_stack_offset(this->workspace1_, tail_proc)
                                    + offset),
                    true);

        if (utils::one_of(d_type, bf16, f16)
                || tail_proc != tail_mode::NoTail) {
            if (tail_proc == tail_mode::NoTail) {
                IRB_LOOP(this->load_data(this->zreg(irb, this->z_tmp_),
                        this->EVEX_compress_addr(
                                this->diffdst_, (irb * this->vlen_) + offset)));
            } else
                this->load_data(this->zreg(0, this->z_tmp_),
                        this->EVEX_compress_addr(this->rsp,
                                get_stack_offset(this->diffdst_, tail_proc)
                                        + offset),
                        true);

            IRB_LOOP(this->vmulps(this->zreg(irb, dstIdx),
                    this->zreg(irb, this->z_tmp_), this->zreg(irb, dstIdx)));
        } else {
            // f32 no-tail path: multiply directly with a memory operand.
            IRB_LOOP(this->vmulps(this->zreg(irb, dstIdx),
                    this->zreg(irb, dstIdx),
                    this->EVEX_compress_addr(
                            this->diffdst_, (irb * this->vlen_) + offset)));
        }
    };

    // Center element: ws1 and diffdst for the current chunk.
    if (tail_proc == tail_mode::CurrentTail) {
        this->load_data(this->zreg(0, this->zdiffsrc_),
                this->EVEX_compress_addr(
                        rsp, get_stack_offset(this->workspace1_, tail_proc)),
                true);
        this->load_data(this->zreg(0, this->zdiffdst_),
                this->EVEX_compress_addr(
                        rsp, get_stack_offset(this->diffdst_, tail_proc)),
                true);
    } else {
        IRB_LOOP(this->load_data(this->zreg(irb, this->zdiffsrc_),
                this->EVEX_compress_addr(
                        this->workspace1_, irb * this->vlen_)));
        IRB_LOOP(this->load_data(this->zreg(irb, this->zdiffdst_),
                this->EVEX_compress_addr(this->diffdst_, irb * this->vlen_)));
    }

    // zdiffsrc = ws1 * diffdst (center product).
    IRB_LOOP(this->vmulps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffdst_),
            this->zreg(irb, this->zdiffsrc_)));

    // (dest register, mask register, distance-from-center) triples for
    // the "previous" neighbors.
    int reg, mask, pos;
    std::vector<std::tuple<int, int, int>> prev_v;
    prev_v.reserve(this->half_ls_);
    for (int pos = 0; pos < this->half_ls_; ++pos) {
        prev_v.emplace_back(this->z_prev_[pos], this->tmp_mask_prev_[pos],
                this->half_ls_ - pos);
    };
    if (version == across_version::First || version == across_version::Single) {
        // No left neighbor in memory: synthesize it by shifting the
        // center vector and zero-padding at the boundary.
        for (const auto &reg_mask_pos : prev_v) {
            std::tie(reg, mask, pos) = reg_mask_pos;
            load_shifted_padded_with_zeros(
                    reg, this->zdiffsrc_, mask, -1 * pos * mask_shift);
        }
    } else {
        for (const auto &reg_mask_pos : prev_v) {
            std::tie(reg, mask, pos) = reg_mask_pos;
            IRB_LOOP(load_ws_diffdst(reg, -1 * pos * acc_size,
                    tail_proc == tail_mode::CurrentTail ? tail_mode::CurrentTail
                                                        : tail_mode::NoTail));
        }
    }

    // Same scheme for the "next" neighbors.
    std::vector<std::tuple<int, int, int>> next_v;
    next_v.reserve(this->half_ls_);
    for (int pos = 0; pos < this->half_ls_; ++pos) {
        next_v.emplace_back(
                this->z_next_[pos], this->tmp_mask_next_[pos], pos + 1);
    }
    if (version == across_version::Last || version == across_version::Single) {
        // No right neighbor in memory: shift + zero-pad the center vector.
        for (const auto &reg_mask_pos : next_v) {
            std::tie(reg, mask, pos) = reg_mask_pos;
            load_shifted_padded_with_zeros(
                    reg, this->zdiffsrc_, mask, pos * mask_shift);
        }
    } else {
        for (const auto &reg_mask_pos : next_v) {
            std::tie(reg, mask, pos) = reg_mask_pos;
            IRB_LOOP(load_ws_diffdst(reg, pos * acc_size,
                    tail_proc == tail_mode::NextTail ? tail_mode::NextTail
                                                     : tail_mode::NoTail));
        }
    }
}
409 | |
// Emits the store of the computed diff_src vector(s). The tail case goes
// through store_tail (masked, via the stack scratch at slot 8); the full
// case picks between aligned and unaligned stores with a runtime check
// of the diffsrc_ pointer's alignment.
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>::store_compute_data(
        int loop_size_param, tail_mode tail_proc, unsigned C_tail) {
    // Consumed implicitly by the IRB_LOOP macro below.
    const int loop_size = loop_size_param;

    if (tail_proc == tail_mode::CurrentTail) {
        this->store_tail(C_tail, this->zreg(0, this->zdiffsrc_), this->diffsrc_,
                0, 8 * zmm_size_, tmp_store_from_stack_idx_tail_);
    } else {
        Label unaligned_store, end_store;
        // Test low bits of the pointer: nonzero means not vlen-aligned.
        this->test(this->diffsrc_, this->vlen_ - 1);
        this->jnz(unaligned_store, this->T_NEAR);
        IRB_LOOP(this->store_data(true,
                this->EVEX_compress_addr(this->diffsrc_, irb * this->vlen_),
                this->zreg(irb, this->zdiffsrc_)));
        this->jmp(end_store, this->T_NEAR);
        this->L(unaligned_store);
        {
            IRB_LOOP(this->store_data(false,
                    this->EVEX_compress_addr(this->diffsrc_, irb * this->vlen_),
                    this->zreg(irb, this->zdiffsrc_)));
        }
        this->L(end_store);
    }
}
435 | |
// Explicit instantiations for the supported data types.
template class jit_avx512_common_lrn_kernel_bwd_nhwc_t<f32>;
template class jit_avx512_common_lrn_kernel_bwd_nhwc_t<bf16>;
template class jit_avx512_common_lrn_kernel_bwd_nhwc_t<f16>;
439 | |
440 | } // namespace lrn |
441 | } // namespace x64 |
442 | } // namespace cpu |
443 | } // namespace impl |
444 | } // namespace dnnl |
445 | |