/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

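// JIT generator for the AVX-512 LRN (local response normalization) forward
// kernel over NHWC data. Channels are processed 16 at a time (one zmm
// vector); channel counts that are not a multiple of 16 are handled through
// zero-padded scratch slots on the stack (see reserve_stack_space and
// load_data_to_stack below).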
#include <numeric>
#include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_nhwc.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace lrn {
template <data_type_t d_type>
jit_avx512_common_lrn_kernel_fwd_nhwc_t<
        d_type>::jit_avx512_common_lrn_kernel_fwd_nhwc_t(unsigned C,
        prop_kind_t prop_kind, float alpha, float beta, float k, int local_size,
        void *code_ptr, size_t code_size)
    : jit_avx512_common_lrn_kernel_fwd_t<d_type>(prop_kind, alpha, beta, k,
            local_size, code_ptr, code_size, jit_name())
    , tmp_mask_prev_ {[this]() {
        std::vector<int> v(this->local_size_ / 2);
        std::iota(v.begin(), v.end(), this->zc_ + 2);
        return v;
    }()}
    , tmp_mask_next_ {[this]() {
        std::vector<int> v(this->local_size_ / 2);
        std::iota(v.begin(), v.end(), this->zc_ + 2 + this->local_size_ / 2);
        return v;
    }()}
    , half_ls_ {(local_size - 1) / 2}
    , C(C) {}

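// Top-level code generation: C is split into C / 16 full 16-channel blocks
// plus a C % 16 tail. When a tail exists, three zmm-sized scratch slots are
// reserved on the stack so that loads around the tail read zero-padded data
// rather than memory past the end of the tensor.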
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::generate() {

    const auto res = std::div(C, 16);
    const auto &C_tail = res.rem;
    const auto &num_full_16c_blocks = res.quot;
    static const auto stack_space = zmm_size * 3;

    this->preamble();
    if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();
    if (C_tail) reserve_stack_space(stack_space);
    this->set_up_ker_params();
    this->execute_compute_loop(num_full_16c_blocks, C_tail);
    if (C_tail) unreserve_stack_space(stack_space);
    this->postamble();
}

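// Reserves the stack scratch area and zeroes its first two zmm slots;
// tail-mode neighbour loads that have no staged data fall back on these
// zeros as padding. The third slot (at offset 2 * zmm_size) serves as a
// bounce buffer for the masked tail stores in store_compute_data.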
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::reserve_stack_space(
        std::size_t space) {
    this->sub(rsp, space);
    this->uni_vpxor(zmm4, zmm4, zmm4);
    for (unsigned i = 0; i < 2u; ++i)
        this->vmovups(ptr[rsp + i * zmm_size], zmm4);
}

template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::unreserve_stack_space(
        std::size_t space) {
    this->add(rsp, space);
}

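// Loads the kernel arguments (src/dst pointers, the two workspace pointers
// consumed by LRN backward, and the permute-mask table) from the
// jit_args_fwd_t structure, then broadcasts the scalar alpha and k
// parameters across full zmm registers.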
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::set_up_ker_params() {

#define GET_OFF(field) \
    offsetof(typename jit_avx512_common_lrn_kernel_fwd_t< \
                     d_type>::jit_args_fwd_t, \
            field)
    this->mov(this->src_, ptr[this->param_ + GET_OFF(src)]);
    this->mov(this->dst_, ptr[this->param_ + GET_OFF(dst)]);
    if (this->pk_ != prop_kind::forward_inference) {
        this->mov(this->ws0_, ptr[this->param_ + GET_OFF(ws0)]);
        this->mov(this->ws1_, ptr[this->param_ + GET_OFF(ws1)]);
    }
    this->mov(this->mask_, ptr[this->param_ + GET_OFF(mask_ptr)]);
#undef GET_OFF

    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->vmovq(this->xalpha_, this->imm_addr64_);
    this->vbroadcastss(this->zalpha_, this->xalpha_);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->vmovq(this->xk_, this->imm_addr64_);
    this->vbroadcastss(this->zk_, this->xk_);
}

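// Walks the channel dimension. A lone full block or a lone tail takes the
// across_version::Single path. Otherwise the first block is emitted
// separately, the interior blocks run in a register-blocked runtime loop
// (LS iterations of reg_block_ blocks) followed by a remainder (LSREST),
// one block is peeled off as LTAIL when the chunk after it is the tail
// (it must preload that tail onto the stack), and the final compute_loop
// handles the last full block or the tail itself.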
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::execute_compute_loop(
        unsigned num_full_16c_blocks, unsigned C_tail) {

    if ((num_full_16c_blocks == 1u && !C_tail)
            || (num_full_16c_blocks == 0u && C_tail)) {
        const auto tail_proc
                = C_tail ? tail_mode::CurrentTail : tail_mode::NoTail;
        compute_loop(across_version::Single, tail_proc, C_tail);
    } else {
        const int begin_end = C_tail ? 1 : 2;
        int middle_16_c_blocks = num_full_16c_blocks == 1
                ? 0
                : num_full_16c_blocks - begin_end;
        int LTAIL = 0;
        if (C_tail && middle_16_c_blocks) {
            middle_16_c_blocks -= 1;
            LTAIL = 1;
        }

        const int LSREST = middle_16_c_blocks % this->reg_block_;
        const int LS = middle_16_c_blocks - LSREST;

        if (LS > 0) this->mov(this->blockC_, LS);
        const auto first_tail_proc = num_full_16c_blocks == 1
                ? tail_mode::NextTail
                : tail_mode::NoTail;
        compute_loop(across_version::First, first_tail_proc, C_tail);
        increment_loop_params(this->vlen_);

        Label lrn_loop;

        if (LS > 0) {

            this->L(lrn_loop);
            {
                compute_loop(across_version::Middle, tail_mode::NoTail, C_tail,
                        this->reg_block_);
                increment_loop_params(this->reg_block_ * this->vlen_);
                this->sub(this->blockC_, this->reg_block_);
                this->cmp(this->blockC_, 0);
                this->jne(lrn_loop, this->T_NEAR);
            }
        }

        if (LSREST > 0) {
            compute_loop(
                    across_version::Middle, tail_mode::NoTail, C_tail, LSREST);
            increment_loop_params(LSREST * this->vlen_);
        }

        if (LTAIL) {
            compute_loop(
                    across_version::Middle, tail_mode::NextTail, C_tail, LTAIL);
            increment_loop_params(LTAIL * this->vlen_);
        }

        const auto last_tail_proc
                = C_tail ? tail_mode::CurrentTail : tail_mode::NoTail;
        compute_loop(across_version::Last, last_tail_proc, C_tail);
    }
}

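// Advances all live data pointers by `offset` bytes; the workspace pointers
// only exist outside forward_inference.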
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::increment_loop_params(
        std::size_t offset) {

    this->add(this->src_, offset);
    this->add(this->dst_, offset);
    if (this->pk_ != prop_kind::forward_inference) {
        this->add(this->ws0_, offset);
        this->add(this->ws1_, offset);
    }
}

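// One compute step over up to reg_block_ 16-channel blocks: optionally
// stage tail data on the stack, load the centre and neighbour channel
// vectors, evaluate the LRN expression, and write the results out.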
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::compute_loop(
        across_version version, tail_mode tail_proc, unsigned C_tail,
        int loop_size_param) {

    if (tail_proc != tail_mode::NoTail)
        load_data_to_stack(C_tail, version, tail_proc);
    load_compute_data(version, tail_proc, loop_size_param);
    compute(loop_size_param);
    store_compute_data(loop_size_param, tail_proc, C_tail);
}

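// Stages data for tail processing: the vector immediately preceding the
// tail region goes to rsp + 0 (skipped for Single, where the zeroed slot
// remains) and the zero-padded tail itself to rsp + zmm_size, so neighbour
// loads can address the stack instead of reading past the source buffer.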
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::load_data_to_stack(
        unsigned C_tail, across_version version, tail_mode tail_proc) {
    if (version != across_version::Single) {
        const int previousChunkOffset
                = tail_proc == tail_mode::NextTail ? 0 : -1 * this->vlen_;
        this->load_data(this->zreg(0, tmp_load_to_stack_idx_prev_),
                this->EVEX_compress_addr(this->src_, previousChunkOffset));
        this->vmovups(this->EVEX_compress_addr(rsp, 0),
                this->zreg(0, tmp_load_to_stack_idx_prev_));
    }

    const int tail_src_mem_offset
            = tail_proc == tail_mode::NextTail ? this->vlen_ : 0;
    static constexpr int tail_dst_stack_offset = zmm_size;
    this->load_tail(C_tail, this->src_, tail_src_mem_offset,
            tail_dst_stack_offset, this->tmp_load_to_stack_idx_tail_);
}

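// Loads the centre channels into zc_ and the local_size_ / 2 preceding and
// following channel vectors into z_prev_ / z_next_. At the tensor borders
// (First/Last/Single) a neighbour is synthesized by permuting the centre
// vector through the int32 index table at mask_ (vpermt2ps), which shifts
// the lanes by `pos` and pulls zeros from the cleared destination for
// out-of-range lanes; near a tail the neighbours are instead read from the
// stack staging area filled by load_data_to_stack.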
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::load_compute_data(
        across_version version, tail_mode tail_proc, int loop_size_param) {

    static constexpr int acc_size = utils::one_of(d_type, bf16, f16)
            ? sizeof(acc_data_bf16_t)
            : sizeof(acc_data_t);

    const int loop_size = loop_size_param;
    static constexpr int mask_shift = sizeof(int32_t);
    const auto load_shifted_padded_with_zeros
            = [&](int dstIdx, int srcIdx, int maskTmpIdx, int offset) {
                  this->uni_vpxor(this->zreg(0, dstIdx), this->zreg(0, dstIdx),
                          this->zreg(0, dstIdx));
                  this->load_data(this->zreg(0, maskTmpIdx),
                          this->EVEX_compress_addr(this->mask_, offset), true);
                  this->vpermt2ps(this->zreg(0, dstIdx),
                          this->zreg(0, maskTmpIdx), this->zreg(0, srcIdx));
              };

    if (tail_proc == tail_mode::CurrentTail) {
        this->load_data(this->zreg(0, this->zc_),
                this->EVEX_compress_addr(rsp, zmm_size), true);
    } else {
        IRB_LOOP(this->load_data(this->zreg(irb, this->zc_),
                this->EVEX_compress_addr(this->src_, irb * this->vlen_)));
    }

    struct entry_t {
        int reg, mask, pos;
        entry_t(int reg, int mask, int pos)
            : reg {reg}, mask {mask}, pos {pos} {}
    };
    std::vector<entry_t> prev_v;
    prev_v.reserve(this->half_ls_);
    for (int pos = 0; pos < this->half_ls_; ++pos) {
        prev_v.emplace_back(this->z_prev_[pos], this->tmp_mask_prev_[pos],
                this->half_ls_ - pos);
    }
    if (version == across_version::First
            || version == across_version::Single) {
        for (const auto &entry : prev_v) {
            load_shifted_padded_with_zeros(entry.reg, this->zc_, entry.mask,
                    -1 * entry.pos * mask_shift);
        }
    } else {
        if (tail_proc == tail_mode::CurrentTail) {
            for (const auto &entry : prev_v) {
                this->load_data(this->zreg(0, entry.reg),
                        this->EVEX_compress_addr(rsp,
                                zmm_size - 1 * entry.pos * sizeof(acc_data_t)),
                        true);
            }
        } else {
            for (const auto &entry : prev_v) {
                IRB_LOOP(this->load_data(this->zreg(irb, entry.reg),
                        this->EVEX_compress_addr(this->src_,
                                (irb * this->vlen_)
                                        - 1 * entry.pos * acc_size)));
            }
        }
    }

    std::vector<entry_t> next_v;
    next_v.reserve(this->half_ls_);
    for (int pos = 0; pos < this->half_ls_; ++pos) {
        next_v.emplace_back(
                this->z_next_[pos], this->tmp_mask_next_[pos], pos + 1);
    }
    if (version == across_version::Last
            || version == across_version::Single) {
        for (const auto &entry : next_v) {
            load_shifted_padded_with_zeros(
                    entry.reg, this->zc_, entry.mask, entry.pos * mask_shift);
        }
    } else {
        if (tail_proc == tail_mode::NextTail) {
            for (const auto &entry : next_v) {
                this->load_data(this->zreg(0, entry.reg),
                        this->EVEX_compress_addr(
                                rsp, entry.pos * sizeof(acc_data_t)),
                        true);
            }
        } else {
            for (const auto &entry : next_v) {
                IRB_LOOP(this->load_data(this->zreg(irb, entry.reg),
                        this->EVEX_compress_addr(this->src_,
                                (irb * this->vlen_) + entry.pos * acc_size)));
            }
        }
    }
}

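// Computes base = k + alpha * sum(src^2) over the LRN window, keeping a
// copy in zbase_ for the store phase, then raises it to the power beta in
// zsum_. The beta != 1 path cubes the base and takes two square roots,
// yielding base^(3/4); this sequence is exact only for beta == 0.75, i.e.
// it assumes beta is 0.75 whenever it is not 1.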
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::compute(
        int loop_size_param) {

    const int loop_size = loop_size_param;

    IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->zc_), this->zreg(irb, this->zc_)));

    for (const auto reg : this->z_prev_)
        IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
                this->zreg(irb, reg), this->zreg(irb, reg)));
    for (const auto reg : this->z_next_)
        IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
                this->zreg(irb, reg), this->zreg(irb, reg)));

    IRB_LOOP(this->vfmadd132ps(
            this->zreg(irb, this->zsum_), this->zk_, this->zalpha_));
    IRB_LOOP(this->vmovaps(
            this->zreg(irb, this->zbase_), this->zreg(irb, this->zsum_)));

    if (this->beta_ != 1) {
        IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum2_),
                this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum_)));
        IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum_),
                this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum2_)));

        for (unsigned i = 0; i < 2; ++i)
            IRB_LOOP(this->vsqrtps(this->zreg(irb, this->zsum_),
                    this->zreg(irb, this->zsum_)));
    }
}

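// Stores the results: dst = src / base^beta; for training, additionally
// ws0 = base^beta and ws1 = dst / base (= src / base^(beta + 1)) are saved
// for the backward pass. Tail stores are routed through the stack bounce
// buffer at rsp + 2 * zmm_size by store_tail.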
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>::store_compute_data(
        int loop_size_param, tail_mode tail_proc, unsigned C_tail) {

    const int loop_size = loop_size_param;
    static const int ytmp = this->zsum2_;

    if (this->pk_ != prop_kind::forward_inference) {
        // save intermediate results for lrn backward
        if (tail_proc == tail_mode::CurrentTail)
            this->store_tail(C_tail, this->zreg(0, this->zsum_), this->ws0_, 0,
                    2 * zmm_size, tmp_store_from_stack_idx_tail_);
        else
            IRB_LOOP(this->store_data(
                    this->EVEX_compress_addr(this->ws0_, irb * this->vlen_),
                    this->zreg(irb, this->zsum_), this->yreg(irb, ytmp)));
    }
    IRB_LOOP(this->vdivps(this->zreg(irb, this->zdst_),
            this->zreg(irb, this->zsrc_), this->zreg(irb, this->zsum_)));
    // storing to dst
    if (tail_proc == tail_mode::CurrentTail)
        this->store_tail(C_tail, this->zreg(0, this->zdst_), this->dst_, 0,
                2 * zmm_size, tmp_store_from_stack_idx_tail_);
    else
        IRB_LOOP(this->store_data(
                this->EVEX_compress_addr(this->dst_, irb * this->vlen_),
                this->zreg(irb, this->zdst_), this->yreg(irb, ytmp)));

    if (this->pk_ != prop_kind::forward_inference) {
        // calculate and save more intermediate results for lrn backward
        /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */
        IRB_LOOP(this->vdivps(this->zreg(irb, this->zsum_),
                this->zreg(irb, this->zdst_), this->zreg(irb, this->zbase_)));

        if (tail_proc == tail_mode::CurrentTail)
            this->store_tail(C_tail, this->zreg(0, this->zsum_), this->ws1_, 0,
                    2 * zmm_size, tmp_store_from_stack_idx_tail_);
        else
            IRB_LOOP(this->store_data(
                    this->EVEX_compress_addr(this->ws1_, irb * this->vlen_),
                    this->zreg(irb, this->zsum_), this->yreg(irb, ytmp)));
    }
}

template class jit_avx512_common_lrn_kernel_fwd_nhwc_t<f32>;
template class jit_avx512_common_lrn_kernel_fwd_nhwc_t<bf16>;
template class jit_avx512_common_lrn_kernel_fwd_nhwc_t<f16>;

} // namespace lrn
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl