1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cmath> |
18 | #include "common/c_types_map.hpp" |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | #include "common/utils.hpp" |
22 | |
23 | #include "cpu/x64/lrn/jit_uni_lrn.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
30 | using namespace dnnl::impl::format_tag; |
31 | using namespace dnnl::impl::status; |
32 | using namespace dnnl::impl::utils; |
33 | |
// Hard upper bound on the LRN local size; the within-channel checks in the
// pd_t::init() functions clamp their own JIT limit against this value.
static constexpr int MAX_LOCAL_SIZE = 32;
35 | |
36 | static dnnl_dim_t compute_n_summands( |
37 | dnnl_dim_t size, int ndims, const dnnl_alg_kind_t &alg_kind) { |
38 | return alg_kind == alg_kind::lrn_across_channels |
39 | ? size |
40 | : std::pow(size, ndims - 2); |
41 | }; |
42 | |
43 | template <cpu_isa_t isa, data_type_t d_type> |
44 | jit_uni_lrn_fwd_t<isa, d_type>::jit_uni_lrn_fwd_t(const pd_t *apd) |
45 | : primitive_t(apd) |
46 | , ker_(nullptr) |
47 | , ker_first_(nullptr) |
48 | , ker_last_(nullptr) {} |
49 | |
50 | template <cpu_isa_t isa, data_type_t d_type> |
51 | jit_uni_lrn_fwd_t<isa, d_type>::~jit_uni_lrn_fwd_t() = default; |
52 | |
53 | template <cpu_isa_t isa, data_type_t d_type> |
54 | status_t jit_uni_lrn_fwd_t<isa, d_type>::init(engine_t *engine) { |
55 | using namespace alg_kind; |
56 | |
57 | const int C = pd()->C(); |
58 | const int H = pd()->H(); |
59 | const int W = pd()->W(); |
60 | const int ndims = memory_desc_wrapper(pd()->src_md()).ndims(); |
61 | const int ls = pd()->desc()->local_size; |
62 | const float K = pd()->desc()->lrn_k; |
63 | const auto pk = pd()->desc()->prop_kind; |
64 | const auto ak = pd()->desc()->alg_kind; |
65 | const auto dat_tag = pd()->dat_tag_; |
66 | const float A = pd()->desc()->lrn_alpha / compute_n_summands(ls, ndims, ak); |
67 | |
68 | if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { |
69 | ker_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>( |
70 | nchw8c_across_t(H, W, 0), A, K, pk); |
71 | ker_first_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>( |
72 | nchw8c_across_t(H, W, -1), A, K, pk); |
73 | ker_last_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>( |
74 | nchw8c_across_t(H, W, +1), A, K, pk); |
75 | } else if (one_of(dat_tag, nhwc, nChw8c, nChw16c) |
76 | && ak == lrn_within_channel) { |
77 | |
78 | ker_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>( |
79 | within_config_t(H, W, C, ls, dat_tag), A, K, pk); |
80 | } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { |
81 | ker_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>( |
82 | nchw_across_t(C, H * W, 0), A, K, pk); |
83 | const int remind = (H * W) % VECTOR_LENGTH; |
84 | if (remind != 0) { |
85 | ker_last_ |
86 | = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>( |
87 | nchw_across_t(C, H * W, remind), A, K, pk); |
88 | } |
89 | } else { |
90 | ker_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>( |
91 | nhwc_across_t(C), A, K, pk); |
92 | } |
93 | CHECK(ker_->create_kernel()); |
94 | if (ker_first_) CHECK(ker_first_->create_kernel()); |
95 | if (ker_last_) CHECK(ker_last_->create_kernel()); |
96 | return status::success; |
97 | } |
98 | |
99 | template <cpu_isa_t isa, data_type_t d_type> |
100 | status_t jit_uni_lrn_fwd_t<isa, d_type>::execute_forward( |
101 | const exec_ctx_t &ctx) const { |
102 | using namespace alg_kind; |
103 | |
104 | status_t status = status::success; |
105 | |
106 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
107 | auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); |
108 | CHECK(status); |
109 | auto ws = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_WORKSPACE, status); |
110 | CHECK(status); |
111 | |
112 | const int N = pd()->MB(); |
113 | const int C = pd()->C(); |
114 | const int HW = pd()->H() * pd()->W(); |
115 | const int ls = pd()->desc()->local_size; |
116 | |
117 | const auto ak = pd()->desc()->alg_kind; |
118 | const auto dat_tag = pd()->dat_tag_; |
119 | const auto ker_first = ker_first_.get(); |
120 | const auto ker = ker_.get(); |
121 | const auto ker_last = ker_last_.get(); |
122 | |
123 | if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { |
124 | parallel_nd(N, C / VECTOR_LENGTH, [&](dim_t n, dim_t c8) { |
125 | const auto offset = n * HW * C + c8 * HW * VECTOR_LENGTH; |
126 | auto ws_ptr = ws ? &ws[offset] : nullptr; |
127 | jit_args_fwd_t args {&src[offset], &dst[offset], ws_ptr, nullptr}; |
128 | if (c8 == 0) |
129 | (*ker_first)(&args); |
130 | else if (c8 == C / VECTOR_LENGTH - 1) |
131 | (*ker_last)(&args); |
132 | else |
133 | (*ker)(&args); |
134 | }); |
135 | } else if (one_of(dat_tag, nhwc, nChw8c, nChw16c) |
136 | && ak == lrn_within_channel) { |
137 | parallel_nd(N, C / VECTOR_LENGTH, [&](dim_t n, dim_t c) { |
138 | const std::size_t offset = dat_tag == nhwc |
139 | ? n * HW * C + c * VECTOR_LENGTH |
140 | : n * HW * C + c * HW * VECTOR_LENGTH; |
141 | auto ws0_ptr = ws ? &ws[offset] : nullptr; |
142 | auto ws1_ptr = ws ? &ws[offset + N * C * HW] : nullptr; |
143 | jit_args_fwd_t args {&src[offset], &dst[offset], ws0_ptr, ws1_ptr}; |
144 | (*ker)(&args); |
145 | }); |
146 | } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { |
147 | parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH, |
148 | [&](dim_t n, dim_t hw8) { |
149 | const auto offset = n * HW * C + hw8 * VECTOR_LENGTH; |
150 | auto ws0_ptr = ws ? &ws[offset] : nullptr; |
151 | jit_args_fwd_t args { |
152 | &src[offset], &dst[offset], ws0_ptr, nullptr}; |
153 | |
154 | if ((hw8 + 1) * VECTOR_LENGTH > HW) |
155 | (*ker_last)(&args); |
156 | else |
157 | (*ker)(&args); |
158 | }); |
159 | } else { // nhwc |
160 | parallel_nd(N, HW, [&](dim_t n, dim_t hw) { |
161 | const auto offset = n * HW * C + hw * C; |
162 | auto ws_ptr = ws ? &ws[offset] : nullptr; |
163 | jit_args_fwd_t args {&src[offset], &dst[offset], ws_ptr, nullptr}; |
164 | (*ker)(&args); |
165 | }); |
166 | } |
167 | return status::success; |
168 | } |
169 | |
170 | template <cpu_isa_t isa, data_type_t d_type> |
171 | status_t jit_uni_lrn_fwd_t<isa, d_type>::pd_t::init(engine_t *engine) { |
172 | using namespace prop_kind; |
173 | using namespace alg_kind; |
174 | |
175 | const memory_desc_wrapper src_d(src_md()); |
176 | const memory_desc_wrapper dst_d(dst_md()); |
177 | |
178 | const bool ok = is_fwd() && mayiuse(isa) && !has_zero_dim_memory() |
179 | && everyone_is(d_type, src_d.data_type(), dst_d.data_type()) |
180 | && attr()->has_default_values() && set_default_formats_common() |
181 | && src_d == dst_d && src_d.ndims() == 4 |
182 | && src_d.dims()[1] % VECTOR_LENGTH == 0 |
183 | && src_d.dims()[1] >= 2 * VECTOR_LENGTH && desc()->lrn_beta == 0.75; |
184 | if (!ok) return unimplemented; |
185 | |
186 | dat_tag_ = memory_desc_matches_one_of_tag( |
187 | *src_md(), nChw16c, nChw8c, nchw, nhwc); |
188 | |
189 | const int HW = src_d.dims()[2] * src_d.dims()[3]; |
190 | |
191 | const bool args_ok_across = true && desc()->alg_kind == lrn_across_channels |
192 | && desc()->local_size == 5 && one_of(dat_tag_, nChw8c, nchw, nhwc) |
193 | && everyone_is(data_type::f32, src_d.data_type()) |
194 | /* SSE41: prevent loads smaller than the size of xmm registers, |
195 | * otherwise it will result in an illegal memory read (seg-fault) |
196 | * due to protected memory. */ |
197 | && IMPLICATION(isa == sse41 && dat_tag_ == nchw, HW >= 4) |
198 | && !is_superset(isa, avx512_core); |
199 | |
200 | const int jit_max_local_size = 5; // bigger size triggers too big code size |
201 | const bool args_ok_within = true && desc()->alg_kind == lrn_within_channel |
202 | && desc()->local_size <= (jit_max_local_size <= MAX_LOCAL_SIZE |
203 | ? jit_max_local_size |
204 | : MAX_LOCAL_SIZE) |
205 | && src_d.dims()[2] >= desc()->local_size |
206 | && src_d.dims()[3] >= desc()->local_size |
207 | && IMPLICATION(d_type == data_type::bf16, mayiuse(avx512_core)) |
208 | && IMPLICATION(d_type == data_type::f16, mayiuse(avx512_core_fp16)) |
209 | && (is_superset(isa, avx512_core) ? one_of(dat_tag_, nhwc, nChw16c) |
210 | : one_of(dat_tag_, nhwc, nChw8c)); |
211 | |
212 | const auto status |
213 | = args_ok_across || args_ok_within ? success : unimplemented; |
214 | |
215 | if (desc()->prop_kind == forward_training && status == success) { |
216 | dims_t ws_dims = {MB(), C(), H(), 2 * W()}; |
217 | memory_desc_init_by_tag(ws_md_, 4, ws_dims, d_type, dat_tag_); |
218 | } |
219 | |
220 | return status; |
221 | } |
222 | |
223 | template <cpu_isa_t isa, data_type_t d_type> |
224 | jit_uni_lrn_bwd_t<isa, d_type>::jit_uni_lrn_bwd_t(const pd_t *apd) |
225 | : primitive_t(apd) |
226 | , ker_(nullptr) |
227 | , ker_first_(nullptr) |
228 | , ker_last_(nullptr) {} |
229 | |
230 | template <cpu_isa_t isa, data_type_t d_type> |
231 | jit_uni_lrn_bwd_t<isa, d_type>::~jit_uni_lrn_bwd_t() = default; |
232 | |
233 | template <cpu_isa_t isa, data_type_t d_type> |
234 | status_t jit_uni_lrn_bwd_t<isa, d_type>::init(engine_t *engine) { |
235 | using namespace alg_kind; |
236 | const int C = pd()->C(); |
237 | const int H = pd()->H(); |
238 | const int W = pd()->W(); |
239 | const int &ls = pd()->desc()->local_size; |
240 | const auto &ak = pd()->desc()->alg_kind; |
241 | const int ndims = memory_desc_wrapper(pd()->src_md()).ndims(); |
242 | const float A = pd()->desc()->lrn_alpha / compute_n_summands(ls, ndims, ak); |
243 | const float &B = pd()->desc()->lrn_beta; |
244 | const auto &dat_tag = pd()->dat_tag_; |
245 | |
246 | if (one_of(dat_tag, nhwc, nChw8c, nChw16c) && ak == lrn_within_channel) { |
247 | ker_ = utils::make_unique<jit_uni_lrn_bwd_kernel_t<isa, d_type>>( |
248 | within_config_t(H, W, C, ls, dat_tag), A, B); |
249 | } else { |
250 | int use_h_parallelism = 0; // XXX |
251 | if (C / VECTOR_LENGTH == 1) { |
252 | ker_ = utils::make_unique<jit_uni_lrn_bwd_kernel_t<isa, d_type>>( |
253 | nchw8c_across_t(H, W, 3), A, B, use_h_parallelism); |
254 | } else { |
255 | ker_ = utils::make_unique<jit_uni_lrn_bwd_kernel_t<isa, d_type>>( |
256 | nchw8c_across_t(H, W, 0), A, B, use_h_parallelism); |
257 | ker_first_ |
258 | = utils::make_unique<jit_uni_lrn_bwd_kernel_t<isa, d_type>>( |
259 | nchw8c_across_t(H, W, -1), A, B, use_h_parallelism); |
260 | ker_last_ |
261 | = utils::make_unique<jit_uni_lrn_bwd_kernel_t<isa, d_type>>( |
262 | nchw8c_across_t(H, W, +1), A, B, use_h_parallelism); |
263 | } |
264 | } |
265 | CHECK(ker_->create_kernel()); |
266 | if (ker_first_) CHECK(ker_first_->create_kernel()); |
267 | if (ker_last_) CHECK(ker_last_->create_kernel()); |
268 | return status::success; |
269 | } |
270 | |
271 | template <cpu_isa_t isa, data_type_t d_type> |
272 | status_t jit_uni_lrn_bwd_t<isa, d_type>::execute_backward( |
273 | const exec_ctx_t &ctx) const { |
274 | status_t status = status::success; |
275 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
276 | auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
277 | auto ws = CTX_IN_MEM(const data_t *, DNNL_ARG_WORKSPACE); |
278 | auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status); |
279 | CHECK(status); |
280 | |
281 | const int N = pd()->MB(); |
282 | const int C = pd()->C(); |
283 | const int H = pd()->H(); |
284 | const int W = pd()->W(); |
285 | const auto ak = pd()->desc()->alg_kind; |
286 | const auto &dat_tag = pd()->dat_tag_; |
287 | |
288 | static constexpr bool use_h_parallelism = false; // XXX |
289 | |
290 | const auto ker = ker_.get(); |
291 | const auto ker_first = ker_first_.get(); |
292 | const auto ker_last = ker_last_.get(); |
293 | const auto tensor_size = N * C * H * W; |
294 | |
295 | if (one_of(dat_tag, nhwc, nChw8c, nChw16c) |
296 | && ak == alg_kind::lrn_within_channel) { |
297 | parallel_nd(N, C / VECTOR_LENGTH, [&](dim_t n, dim_t c) { |
298 | const std::size_t offset = dat_tag == nhwc |
299 | ? n * H * W * C + c * VECTOR_LENGTH |
300 | : n * H * W * C + c * H * W * VECTOR_LENGTH; |
301 | jit_args_bwd_t args {&src[offset], &diff_dst[offset], &ws[offset], |
302 | &ws[offset + tensor_size], &diff_src[offset]}; |
303 | (*ker)(&args); |
304 | }); |
305 | } else if (use_h_parallelism) { |
306 | parallel_nd(N, C / VECTOR_LENGTH, H, [&](dim_t n, dim_t c8, dim_t h) { |
307 | const std::size_t offset = n * C * H * W |
308 | + c8 * H * W * VECTOR_LENGTH + h * W * VECTOR_LENGTH; |
309 | jit_args_bwd_t args {&src[offset], &diff_dst[offset], &ws[offset], |
310 | nullptr, &diff_src[offset]}; |
311 | if (C / VECTOR_LENGTH == 1) |
312 | (*ker)(&args); |
313 | else if (c8 == 0) |
314 | (*ker_first)(&args); |
315 | else if (c8 == C / VECTOR_LENGTH - 1) |
316 | (*ker_last)(&args); |
317 | else |
318 | (*ker)(&args); |
319 | }); |
320 | } else { |
321 | parallel_nd(N, C / VECTOR_LENGTH, [&](dim_t n, dim_t c8) { |
322 | const std::size_t offset |
323 | = n * C * H * W + c8 * H * W * VECTOR_LENGTH; |
324 | jit_args_bwd_t args {&src[offset], &diff_dst[offset], &ws[offset], |
325 | nullptr, &diff_src[offset]}; |
326 | if (C / VECTOR_LENGTH == 1) |
327 | (*ker)(&args); |
328 | else if (c8 == 0) |
329 | (*ker_first)(&args); |
330 | else if (c8 == C / VECTOR_LENGTH - 1) |
331 | (*ker_last)(&args); |
332 | else |
333 | (*ker)(&args); |
334 | }); |
335 | } |
336 | return status::success; |
337 | } |
338 | |
339 | template <cpu_isa_t isa, data_type_t d_type> |
340 | status_t jit_uni_lrn_bwd_t<isa, d_type>::pd_t::init(engine_t *engine) { |
341 | using namespace prop_kind; |
342 | using namespace alg_kind; |
343 | |
344 | const memory_desc_wrapper src_d(src_md()); |
345 | const memory_desc_wrapper diff_src_d(diff_src_md()); |
346 | const memory_desc_wrapper diff_dst_d(diff_dst_md()); |
347 | |
348 | const bool ok = !is_fwd() && mayiuse(avx512_core) && !has_zero_dim_memory() |
349 | && utils::everyone_is(d_type, src_d.data_type(), |
350 | diff_src_d.data_type(), diff_dst_d.data_type()) |
351 | && src_d.ndims() == 4 && attr()->has_default_values() |
352 | && set_default_formats_common() && src_d == diff_dst_d |
353 | && diff_dst_d == diff_src_d && src_d.dims()[1] % VECTOR_LENGTH == 0 |
354 | && src_d.dims()[1] >= 2 * VECTOR_LENGTH && desc()->lrn_beta == 0.75; |
355 | if (!ok) return unimplemented; |
356 | |
357 | dat_tag_ = memory_desc_matches_one_of_tag( |
358 | *src_md(), nChw16c, nChw8c, nchw, nhwc); |
359 | |
360 | const dims_t ws_dims = {MB(), C(), H(), 2 * W()}; |
361 | memory_desc_init_by_tag(ws_md_, 4, ws_dims, d_type, dat_tag_); |
362 | |
363 | if (!compare_ws(hint_fwd_pd_)) return unimplemented; |
364 | |
365 | const bool args_ok_across = true && desc()->alg_kind == lrn_across_channels |
366 | && desc()->local_size == 5 && utils::one_of(dat_tag_, nChw8c) |
367 | && everyone_is(data_type::f32, src_d.data_type()) |
368 | && !is_superset(isa, avx512_core); |
369 | |
370 | const int jit_max_local_size = 5; // bigger size triggers too big code size |
371 | const bool args_ok_within = true && desc()->alg_kind == lrn_within_channel |
372 | && desc()->local_size <= (jit_max_local_size <= MAX_LOCAL_SIZE |
373 | ? jit_max_local_size |
374 | : MAX_LOCAL_SIZE) |
375 | && src_d.dims()[2] >= desc()->local_size |
376 | && src_d.dims()[3] >= desc()->local_size |
377 | && IMPLICATION(d_type == data_type::bf16, mayiuse(avx512_core)) |
378 | && IMPLICATION(d_type == data_type::f16, mayiuse(avx512_core_fp16)) |
379 | && (isa == avx512_core ? one_of(dat_tag_, nhwc, nChw16c) |
380 | : one_of(dat_tag_, nhwc, nChw8c)); |
381 | |
382 | return args_ok_across || args_ok_within ? success : unimplemented; |
383 | } |
384 | |
385 | template struct jit_uni_lrn_fwd_t<avx512_core, dnnl::impl::data_type::f32>; |
386 | template struct jit_uni_lrn_fwd_t<avx512_core, dnnl::impl::data_type::bf16>; |
387 | template struct jit_uni_lrn_fwd_t<avx512_core_fp16, dnnl::impl::data_type::f16>; |
388 | template struct jit_uni_lrn_fwd_t<avx2, dnnl::impl::data_type::f32>; |
389 | template struct jit_uni_lrn_fwd_t<sse41, dnnl::impl::data_type::f32>; |
390 | template struct jit_uni_lrn_bwd_t<avx512_core, dnnl::impl::data_type::f32>; |
391 | template struct jit_uni_lrn_bwd_t<avx512_core, dnnl::impl::data_type::bf16>; |
392 | template struct jit_uni_lrn_bwd_t<avx512_core_fp16, dnnl::impl::data_type::f16>; |
393 | template struct jit_uni_lrn_bwd_t<avx2, dnnl::impl::data_type::f32>; |
394 | |
395 | } // namespace x64 |
396 | } // namespace cpu |
397 | } // namespace impl |
398 | } // namespace dnnl |
399 | |
400 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
401 | |