1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <numeric> |
18 | #include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_nhwc.hpp" |
19 | |
20 | namespace dnnl { |
21 | namespace impl { |
22 | namespace cpu { |
23 | namespace x64 { |
24 | namespace lrn { |
25 | |
template <data_type_t d_type>
jit_avx512_common_lrn_kernel_fwd_nhwc_t<
        d_type>::jit_avx512_common_lrn_kernel_fwd_nhwc_t(unsigned C,
        prop_kind_t prop_kind, float alpha, float beta, float k, int local_size,
        void *code_ptr, size_t code_size)
    // Base ctor stores the LRN hyper-parameters and the codegen buffer.
    : jit_avx512_common_lrn_kernel_fwd_t<d_type>(prop_kind, alpha, beta, k,
            local_size, code_ptr, code_size, jit_name())
    // Scratch zmm register indices used for the permutation masks of the
    // channels *preceding* the current 16-channel block: local_size_/2
    // consecutive ids starting two registers past zc_.
    , tmp_mask_prev_ {[this]() {
        std::vector<int> v(this->local_size_ / 2);
        std::iota(v.begin(), v.end(), this->zc_ + 2);
        return v;
    }()}
    // Same for the channels *following* the block; the ids continue
    // directly after tmp_mask_prev_'s range.
    , tmp_mask_next_ {[this]() {
        std::vector<int> v(this->local_size_ / 2);
        std::iota(v.begin(), v.end(), this->zc_ + 2 + this->local_size_ / 2);
        return v;
    }()}
    // Number of neighbor channels on each side of the center channel.
    , half_ls_ {(local_size - 1) / 2}
    , C(C) {}
45 | |
46 | template <data_type_t d_type> |
47 | void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::generate() { |
48 | |
49 | const auto res = std::div(C, 16); |
50 | const auto &C_tail = res.rem; |
51 | const auto &num_full_16c_blocks = res.quot; |
52 | static const auto stack_space = zmm_size * 3; |
53 | |
54 | this->preamble(); |
55 | if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16(); |
56 | if (C_tail) reserve_stack_space(stack_space); |
57 | this->set_up_ker_params(); |
58 | this->execute_compute_loop(num_full_16c_blocks, C_tail); |
59 | if (C_tail) unreserve_stack_space(stack_space); |
60 | this->postamble(); |
61 | } |
62 | |
63 | template <data_type_t d_type> |
64 | void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::reserve_stack_space( |
65 | std::size_t space) { |
66 | this->sub(rsp, space); |
67 | this->uni_vpxor(zmm4, zmm4, zmm4); |
68 | for (unsigned i = 0; i < 2u; ++i) |
69 | this->vmovups(ptr[rsp + i * zmm_size], zmm4); |
70 | } |
71 | |
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::unreserve_stack_space(
        std::size_t space) {
    // Release the scratch area claimed by reserve_stack_space().
    this->add(rsp, space);
}
77 | |
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::set_up_ker_params() {

// Byte offset of `field` within the kernel's runtime argument struct.
#define GET_OFF(field) \
    offsetof(typename jit_avx512_common_lrn_kernel_fwd_t< \
                     d_type>::jit_args_fwd_t, \
            field)
    // Load the data pointers from the argument struct (param_ register).
    this->mov(this->src_, ptr[this->param_ + GET_OFF(src)]);
    this->mov(this->dst_, ptr[this->param_ + GET_OFF(dst)]);
    // Workspace pointers are only supplied when training (backward pass
    // consumes them).
    if (this->pk_ != prop_kind::forward_inference) {
        this->mov(this->ws0_, ptr[this->param_ + GET_OFF(ws0)]);
        this->mov(this->ws1_, ptr[this->param_ + GET_OFF(ws1)]);
    }
    this->mov(this->mask_, ptr[this->param_ + GET_OFF(mask_ptr)]);
#undef GET_OFF

    // Broadcast the scalar alpha and k hyper-parameters across all lanes
    // of their dedicated zmm registers.
    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->vmovq(this->xalpha_, this->imm_addr64_);
    this->vbroadcastss(this->zalpha_, this->xalpha_);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->vmovq(this->xk_, this->imm_addr64_);
    this->vbroadcastss(this->zk_, this->xk_);
}
102 | |
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::execute_compute_loop(
        unsigned num_full_16c_blocks, unsigned C_tail) {

    // A lone 16-channel block (or a lone tail) needs zero-padding on both
    // sides, so it is handled entirely by the one-shot `Single` version.
    if ((num_full_16c_blocks == 1u && !C_tail)
            || (num_full_16c_blocks == 0u && C_tail)) {
        const auto tail_proc
                = C_tail ? tail_mode::CurrentTail : tail_mode::NoTail;
        compute_loop(across_version::Single, tail_proc, C_tail);
    } else {
        // Peel off the first and last iterations: when a channel tail
        // exists, the `Last` iteration is the tail itself rather than a
        // full block, so only one peeled block is taken from
        // num_full_16c_blocks instead of two.
        const int begin_end = C_tail ? 1 : 2;
        int middle_16_c_blocks = num_full_16c_blocks == 1
                ? 0
                : num_full_16c_blocks - begin_end;
        // The middle block directly before the tail must fetch its right
        // neighbor from the stack copy (NextTail), so peel it off too.
        int LTAIL = 0;
        if (C_tail && middle_16_c_blocks) {
            middle_16_c_blocks -= 1;
            LTAIL = 1;
        }

        // Unroll the middle part by reg_block_; LSREST is the leftover
        // that does not fill a whole unrolled iteration.
        const int LSREST = middle_16_c_blocks % this->reg_block_;
        const int LS = middle_16_c_blocks - LSREST;

        if (LS > 0) this->mov(this->blockC_, LS);
        // With exactly one full block, `First` is also the block adjacent
        // to the tail and must read its right neighbor from the stack.
        const auto first_tail_proc = num_full_16c_blocks == 1
                ? tail_mode::NextTail
                : tail_mode::NoTail;
        compute_loop(across_version::First, first_tail_proc, C_tail);
        increment_loop_params(this->vlen_);

        Label lrn_loop;

        // Main unrolled loop over LS blocks, reg_block_ at a time.
        if (LS > 0) {

            this->L(lrn_loop);
            {
                compute_loop(across_version::Middle, tail_mode::NoTail, C_tail,
                        this->reg_block_);
                increment_loop_params(this->reg_block_ * this->vlen_);
                this->sub(this->blockC_, this->reg_block_);
                this->cmp(this->blockC_, 0);
                this->jne(lrn_loop, this->T_NEAR);
            }
        }

        // Remainder of the middle part, processed in one partial unroll.
        if (LSREST > 0) {
            compute_loop(
                    across_version::Middle, tail_mode::NoTail, C_tail, LSREST);
            increment_loop_params(LSREST * this->vlen_);
        }

        // The peeled pre-tail block: its right neighbor is the tail,
        // which it reads from the zero-padded stack copy.
        if (LTAIL) {
            compute_loop(
                    across_version::Middle, tail_mode::NextTail, C_tail, LTAIL);
            increment_loop_params(LTAIL * this->vlen_);
        }

        // Final iteration: the tail itself, or the last full block when
        // the channel count is a multiple of 16.
        const auto last_tail_proc
                = C_tail ? tail_mode::CurrentTail : tail_mode::NoTail;
        compute_loop(across_version::Last, last_tail_proc, C_tail);
    }
}
165 | |
166 | template <data_type_t d_type> |
167 | void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::increment_loop_params( |
168 | std::size_t offset) { |
169 | |
170 | this->add(this->src_, offset); |
171 | this->add(this->dst_, offset); |
172 | if (this->pk_ != prop_kind::forward_inference) { |
173 | this->add(this->ws0_, offset); |
174 | this->add(this->ws1_, offset); |
175 | } |
176 | } |
177 | |
178 | template <data_type_t d_type> |
179 | void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::compute_loop( |
180 | across_version version, tail_mode tail_proc, unsigned C_tail, |
181 | int loop_size_param) { |
182 | |
183 | if (tail_proc != tail_mode::NoTail) |
184 | load_data_to_stack(C_tail, version, tail_proc); |
185 | load_compute_data(version, tail_proc, loop_size_param); |
186 | compute(loop_size_param); |
187 | store_compute_data(loop_size_param, tail_proc, C_tail); |
188 | } |
189 | |
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::load_data_to_stack(
        unsigned C_tail, across_version version, tail_mode tail_proc) {
    // Stage the channel tail and its left neighbor on the stack so the
    // compute path can read full, zero-padded f32 vectors:
    //   [rsp + 0]        : previous 16-channel chunk (skipped for Single)
    //   [rsp + zmm_size] : the tail itself, zero-padded to a full vector
    if (version != across_version::Single) {
        // For NextTail, src_ still points at the chunk *before* the tail
        // (offset 0); for CurrentTail it has already advanced past it, so
        // the previous chunk is one vector back.
        const int previousChunkOffset
                = tail_proc == tail_mode::NextTail ? 0 : -1 * this->vlen_;
        this->load_data(this->zreg(0, tmp_load_to_stack_idx_prev_),
                this->EVEX_compress_addr(this->src_, previousChunkOffset));
        this->vmovups(this->EVEX_compress_addr(rsp, 0),
                this->zreg(0, tmp_load_to_stack_idx_prev_));
    }

    // Copy the C_tail remaining channels to the stack, zero-padded.
    const int tail_src_mem_offset
            = tail_proc == tail_mode::NextTail ? this->vlen_ : 0;
    static constexpr int tail_dst_stack_offset = zmm_size;
    this->load_tail(C_tail, this->src_, tail_src_mem_offset,
            tail_dst_stack_offset, this->tmp_load_to_stack_idx_tail_);
}
208 | |
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::load_compute_data(
        across_version version, tail_mode tail_proc, int loop_size_param) {

    // Element size of the source data in memory. bf16/f16 inputs are
    // up-converted to f32 by load_data, but neighbor offsets into src_
    // must still be expressed in source-element units.
    static constexpr int acc_size = utils::one_of(d_type, bf16, f16)
            ? sizeof(acc_data_bf16_t)
            : sizeof(acc_data_t);

    const int loop_size = loop_size_param; // consumed by IRB_LOOP
    static constexpr int mask_shift = sizeof(int32_t);
    // Build a copy of srcIdx shifted along the channel axis with zero
    // padding: dst starts all-zero, then vpermt2ps pulls selected src
    // lanes through the permutation mask loaded from mask_ at `offset`.
    const auto load_shifted_padded_with_zeros
            = [&](int dstIdx, int srcIdx, int maskTmpIdx, int offset) {
                  this->uni_vpxor(this->zreg(0, dstIdx), this->zreg(0, dstIdx),
                          this->zreg(0, dstIdx));
                  this->load_data(this->zreg(0, maskTmpIdx),
                          this->EVEX_compress_addr(this->mask_, offset), true);
                  this->vpermt2ps(this->zreg(0, dstIdx),
                          this->zreg(0, maskTmpIdx), this->zreg(0, srcIdx));
              };

    // Center channels: the tail reads its zero-padded f32 stack copy,
    // everything else reads straight from src_.
    if (tail_proc == tail_mode::CurrentTail) {
        this->load_data(this->zreg(0, this->zc_),
                this->EVEX_compress_addr(rsp, zmm_size), true);
    } else {
        IRB_LOOP(this->load_data(this->zreg(irb, this->zc_),
                this->EVEX_compress_addr(this->src_, irb * this->vlen_)));
    }

    // One entry per neighbor: destination register, permutation-mask
    // scratch register, and distance (in channels) from the center.
    struct entry_t {
        int reg, mask, pos;
        entry_t(int reg, int mask, int pos)
            : reg {reg}, mask {mask}, pos {pos} {}
    };
    std::vector<entry_t> prev_v;
    prev_v.reserve(this->half_ls_);
    for (int pos = 0; pos < this->half_ls_; ++pos) {
        prev_v.emplace_back(this->z_prev_[pos], this->tmp_mask_prev_[pos],
                this->half_ls_ - pos);
    }
    if (version == across_version::First || version == across_version::Single) {
        // No channels exist to the left: shift the center register right
        // and pad with zeros instead of reading out of bounds.
        for (const auto &entry : prev_v) {
            load_shifted_padded_with_zeros(entry.reg, this->zc_, entry.mask,
                    -1 * entry.pos * mask_shift);
        }
    } else {
        if (tail_proc == tail_mode::CurrentTail) {
            // Left neighbors come from the stack staging area, which holds
            // f32 elements regardless of d_type (hence sizeof(acc_data_t)).
            for (const auto &entry : prev_v) {
                this->load_data(this->zreg(0, entry.reg),
                        this->EVEX_compress_addr(rsp,
                                zmm_size - 1 * entry.pos * sizeof(acc_data_t)),
                        true);
            }
        } else {
            // Left neighbors read directly from src_ memory, in
            // source-element (acc_size) strides.
            for (const auto &entry : prev_v) {
                IRB_LOOP(this->load_data(this->zreg(irb, entry.reg),
                        this->EVEX_compress_addr(this->src_,
                                (irb * this->vlen_)
                                        - 1 * entry.pos * acc_size)));
            }
        }
    }

    // Same scheme for the channels following the center block.
    std::vector<entry_t> next_v;
    next_v.reserve(this->half_ls_);
    for (int pos = 0; pos < this->half_ls_; ++pos) {
        next_v.emplace_back(
                this->z_next_[pos], this->tmp_mask_next_[pos], pos + 1);
    }
    if (version == across_version::Last || version == across_version::Single) {
        // No channels to the right: shift left and zero-pad.
        for (const auto &entry : next_v) {
            load_shifted_padded_with_zeros(
                    entry.reg, this->zc_, entry.mask, entry.pos * mask_shift);
        }
    } else {
        if (tail_proc == tail_mode::NextTail) {
            // Right neighbors are the (f32, zero-padded) tail on the stack.
            for (const auto &entry : next_v) {
                this->load_data(this->zreg(0, entry.reg),
                        this->EVEX_compress_addr(
                                rsp, entry.pos * sizeof(acc_data_t)),
                        true);
            }
        } else {
            for (const auto &entry : next_v) {
                IRB_LOOP(this->load_data(this->zreg(irb, entry.reg),
                        this->EVEX_compress_addr(this->src_,
                                (irb * this->vlen_) + entry.pos * acc_size)));
            }
        }
    }
}
299 | |
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::compute(
        int loop_size_param) {

    const int loop_size = loop_size_param; // consumed by IRB_LOOP

    // zsum = sum of squares over the local window: seed with the center
    // channel, then FMA-accumulate every left/right neighbor.
    IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->zc_), this->zreg(irb, this->zc_)));

    for (const auto reg : this->z_prev_)
        IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
                this->zreg(irb, reg), this->zreg(irb, reg)));
    for (const auto reg : this->z_next_)
        IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
                this->zreg(irb, reg), this->zreg(irb, reg)));

    // zsum = k + alpha * sum; keep a copy in zbase for the backward pass.
    IRB_LOOP(this->vfmadd132ps(
            this->zreg(irb, this->zsum_), this->zk_, this->zalpha_));
    IRB_LOOP(this->vmovaps(
            this->zreg(irb, this->zbase_), this->zreg(irb, this->zsum_)));

    // Raise zsum to the power beta: cube it, then take the square root
    // twice, yielding x^(3/4) — i.e. this path assumes beta == 0.75
    // whenever beta != 1 (TODO confirm the pd rejects other betas).
    if (this->beta_ != 1) {
        IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum2_),
                this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum_)));
        IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum_),
                this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum2_)));

        for (unsigned i = 0; i < 2; ++i)
            IRB_LOOP(this->vsqrtps(this->zreg(irb, this->zsum_),
                    this->zreg(irb, this->zsum_)));
    }
}
332 | |
333 | template <data_type_t d_type> |
334 | void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::store_compute_data( |
335 | int loop_size_param, tail_mode tail_proc, unsigned C_tail) { |
336 | |
337 | const int loop_size = loop_size_param; |
338 | static const int ytmp = this->zsum2_; |
339 | |
340 | if (this->pk_ != prop_kind::forward_inference) { |
341 | // save intermediate results for lrn backward |
342 | if (tail_proc == tail_mode::CurrentTail) |
343 | this->store_tail(C_tail, this->zreg(0, this->zsum_), this->ws0_, 0, |
344 | 2 * zmm_size, tmp_store_from_stack_idx_tail_); |
345 | else |
346 | IRB_LOOP(this->store_data( |
347 | this->EVEX_compress_addr(this->ws0_, irb * this->vlen_), |
348 | this->zreg(irb, this->zsum_), this->yreg(irb, ytmp))); |
349 | } |
350 | IRB_LOOP(this->vdivps(this->zreg(irb, this->zdst_), |
351 | this->zreg(irb, this->zsrc_), this->zreg(irb, this->zsum_))); |
352 | // storing to dst |
353 | if (tail_proc == tail_mode::CurrentTail) |
354 | this->store_tail(C_tail, this->zreg(0, this->zdst_), this->dst_, 0, |
355 | 2 * zmm_size, tmp_store_from_stack_idx_tail_); |
356 | else |
357 | IRB_LOOP(this->store_data( |
358 | this->EVEX_compress_addr(this->dst_, irb * this->vlen_), |
359 | this->zreg(irb, this->zdst_), this->yreg(irb, ytmp))); |
360 | |
361 | if (this->pk_ != prop_kind::forward_inference) { |
362 | // calculate and save more intermediate results for lrn backward |
363 | /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */ |
364 | IRB_LOOP(this->vdivps(this->zreg(irb, this->zsum_), |
365 | this->zreg(irb, this->zdst_), this->zreg(irb, this->zbase_))); |
366 | |
367 | if (tail_proc == tail_mode::CurrentTail) |
368 | this->store_tail(C_tail, this->zreg(0, this->zsum_), this->ws1_, 0, |
369 | 2 * zmm_size, tmp_store_from_stack_idx_tail_); |
370 | else |
371 | IRB_LOOP(this->store_data( |
372 | this->EVEX_compress_addr(this->ws1_, irb * this->vlen_), |
373 | this->zreg(irb, this->zsum_), this->yreg(irb, ytmp))); |
374 | } |
375 | } |
376 | |
// Explicit instantiations for every data type this kernel supports.
template class jit_avx512_common_lrn_kernel_fwd_nhwc_t<f32>;
template class jit_avx512_common_lrn_kernel_fwd_nhwc_t<bf16>;
template class jit_avx512_common_lrn_kernel_fwd_nhwc_t<f16>;
380 | |
381 | } // namespace lrn |
382 | } // namespace x64 |
383 | } // namespace cpu |
384 | } // namespace impl |
385 | } // namespace dnnl |
386 | |