1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <math.h> |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | |
24 | #include "cpu/ref_lrn.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | |
namespace {

// Accumulation type for the LRN sums: values are converted to float for the
// windowed accumulation regardless of the source/destination data type.
using acc_data_t = float;

// Computes omega^(-beta).
//
// For the common LRN default beta = 0.75 the result is obtained with two
// sqrtf calls instead of a (much slower) powf:
//     omega^(-3/4) = 1.0f / sqrtf(omega) * sqrtf(1.0f / sqrtf(omega))
//                  = sqrtf(1.0f / (sqrtf(omega) * omega))
// Any other beta falls back to powf.
//
// Note: the original definition carried a stray ';' after the function
// body (an empty declaration); removed.
inline acc_data_t fast_negative_powf(acc_data_t omega, acc_data_t beta) {
    if (beta == 0.75f) return sqrtf(1.0f / (sqrtf(omega) * omega));
    return 1.0f / powf(omega, beta);
}

} // namespace
52 | |
53 | // Forward LRN formula: |
54 | // y_i = x_i * (k + a / n * Sum:j [x_j^2])^-b, where |
55 | // k, a(alpha), b(beta), n(local_size) - lrn hyperparameters; |
56 | // j - kernel points, j in [i - n/2, i + n/2] for ACROSS, 2d-shape for WITHIN; |
57 | |
// Reference (non-JIT) LRN forward pass.
//
// For every destination point i computes (see the formula comment above):
//     y_i = x_i * (k + alpha / n * Sum:j [x_j^2])^(-beta)
// where the sum runs over the `local_size` window around i — along the
// channel axis for ACROSS, over the spatial dims for WITHIN.
//
// @tparam d_type source/destination data type; accumulation is in float.
// @tparam tag    format tag this instantiation is specialized for. nChw16c,
//                nChw8c, nchw and nhwc use hand-written offset arithmetic;
//                any other tag goes through the generic
//                memory_desc_wrapper::off() fallback.
template <impl::data_type_t d_type>
template <impl::format_tag_t tag>
status_t ref_lrn_fwd_t<d_type>::execute_forward(const exec_ctx_t &ctx) const {
    using namespace alg_kind;
    using namespace format_tag;

    status_t status = status::success;

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
    CHECK(status);

    const memory_desc_wrapper data_d(pd()->src_md());

    const dim_t C = pd()->C();
    const dim_t D = pd()->D();
    const dim_t H = pd()->H();
    const dim_t W = pd()->W();
    const auto stride_mb = data_d.blocking_desc().strides[0];
    const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels;
    static constexpr dim_t blksize = tag == nChw16c ? 16 : 8;
    const auto ndims = data_d.ndims();

    // Number of summands `n` in the normalization denominator: `size` for
    // ACROSS (1D window over channels); for WITHIN the window spans the
    // spatial dims, so it is size^(ndims - 2).
    auto compute_n_summands = [&](dim_t size) {
        if (across_channels) {
            return size;
        } else { // within_channel
            dim_t n_summands = 1;
            for (auto d = ndims - 2; d > 0; --d)
                n_summands *= size;
            return n_summands;
        }
    };

    const acc_data_t alpha = static_cast<acc_data_t>(pd()->desc()->lrn_alpha);
    const acc_data_t beta = static_cast<acc_data_t>(pd()->desc()->lrn_beta);
    const acc_data_t k = static_cast<acc_data_t>(pd()->desc()->lrn_k);
    const dim_t size = pd()->desc()->local_size;
    // Window extends half_size points on each side of the center point.
    const dim_t half_size = (size - 1) / 2;
    const dim_t summands = compute_n_summands(size);

    // Maps logical (mb, c, d, h, w) coordinates to the physical element
    // offset. The explicitly supported tags are 4D, so `d` is ignored on
    // those paths; the default path drops trailing spatial coordinates for
    // tensors with fewer dims and defers to the generic offset function.
    auto data_off = [&](dim_t mb, dim_t c, dim_t d, dim_t h, dim_t w) -> dim_t {
        switch (tag) {
            case nChw16c:
            case nChw8c:
                return mb * stride_mb + (c / blksize) * H * W * blksize
                        + h * W * blksize + w * blksize + c % blksize;
            case nchw: return mb * stride_mb + c * H * W + h * W + w;
            case nhwc: return mb * stride_mb + h * W * C + w * C + c;
            default:
                if (ndims >= 5) return data_d.off(mb, c, d, h, w);
                if (ndims >= 4) return data_d.off(mb, c, h, w);
                if (ndims >= 3) return data_d.off(mb, c, w);
                return data_d.off(mb, c);
        }
    };

    // Computes one destination point (mb, oc, od, oh, ow) into *d:
    // accumulates the sum of squares over the window (clamped at the tensor
    // borders), then applies y = x * (k + alpha * sum / n)^(-beta).
    // pass by value due to icc170 and icc180 problem on KNL
    auto ker = [=](data_t *d, dim_t mb, dim_t oc, dim_t od, dim_t oh,
                       dim_t ow) {
        acc_data_t sum = 0;
        if (across_channels) {
            // 1D window along the channel axis, clamped to [0, C).
            const dim_t c_st = nstl::max(oc - half_size + 0, (dim_t)0);
            const dim_t c_en = nstl::min(oc + half_size + 1, C);

            for (dim_t c = c_st; c < c_en; ++c) {
                const acc_data_t s = src[data_off(mb, c, od, oh, ow)];
                sum += s * s;
            }
        } else {
            // Spatial window within the same channel, clamped per-dim.
            dim_t d_st = nstl::max(od - half_size + 0, (dim_t)0);
            dim_t d_en = nstl::min(od + half_size + 1, D);
            dim_t h_st = nstl::max(oh - half_size + 0, (dim_t)0);
            dim_t h_en = nstl::min(oh + half_size + 1, H);
            dim_t w_st = nstl::max(ow - half_size + 0, (dim_t)0);
            dim_t w_en = nstl::min(ow + half_size + 1, W);
            for_(dim_t d = d_st; d < d_en; ++d)
            for_(dim_t h = h_st; h < h_en; ++h)
            for (dim_t w = w_st; w < w_en; ++w) {
                const acc_data_t s = src[data_off(mb, oc, d, h, w)];
                sum += s * s;
            }
        }
        sum = k + alpha * sum / summands;
        const acc_data_t s = src[data_off(mb, oc, od, oh, ow)];
        d[0] = static_cast<data_t>(s * fast_negative_powf(sum, beta));
    };

    const dim_t MB = pd()->MB();
    if (tag == nChw16c || tag == nChw8c) {
        // Blocked layouts: parallelize over channel blocks and vectorize
        // over the `blksize` contiguous channels within a block; the loop
        // bound guards the tail when C is not divisible by blksize.
        // These tags are 4D, hence od = 0.
        parallel_nd(MB, utils::div_up(C, blksize), H, W,
                [&](dim_t mb, dim_t c_blk, dim_t h, dim_t w) {
                    dim_t c = c_blk * blksize;
                    const dim_t off = mb * stride_mb + c * H * W
                            + (h * W + w) * blksize;
                    PRAGMA_OMP_SIMD()
                    for (dim_t cc = 0; cc < nstl::min(blksize, C - c); ++cc)
                        ker(&dst[off + cc], mb, c + cc, 0, h, w);
                });
    } else if (tag == nhwc) {
        // Channels-last layout (4D tag, od = 0): offset computed inline.
        parallel_nd(MB, H, W, C, [&](dim_t mb, dim_t h, dim_t w, dim_t c) {
            const dim_t off = mb * stride_mb + h * W * C + w * C + c;
            ker(&dst[off], mb, c, 0, h, w);
        });
    } else {
        // Generic path: fully parallel over all logical coordinates, using
        // the generic offset function.
        parallel_nd(MB, C, D, H, W,
                [&](dim_t mb, dim_t c, dim_t d, dim_t h, dim_t w) {
                    const dim_t off = data_off(mb, c, d, h, w);
                    ker(&dst[off], mb, c, d, h, w);
                });
    }
    return status::success;
}
171 | |
172 | // Backward LRN formula (refer to Forward LRN formula): |
173 | // Partial derivatives: |
174 | // dy_i/dx_j = - 2*a*b/n * x_i * O(i)^-b / O(i) * x_j, i != j |
175 | // O(i)^-b - 2*a*b/n * x_i * O(i)^-b / O(i) * x_j, i == j, where |
176 | // O(i) = (k + a / n * Sum:j [x_j^2]), j in [i - n/2, i + n/2]. Note: j depends |
177 | // on i, which means that O(i) may use more points than local_size. |
// Now, z_i = Sum:k [dE/dy_k * dy_k/dx_i], where k in [i - n/2, i + n/2]
// for ACROSS. 2d-shape for WITHIN.
// Then, dE/dy_k = diffDst_k. Finally,
// z_i = Sum:k [dd_k * dy_k/dx_i] = A - B (code variables) =
182 | // = dd_i * O(i)^-b - 2*a*b/n * x_i * Sum:k {O(k)^-b / O(k) * x_k * dd_k}; |
183 | |
// Reference (non-JIT) LRN backward pass.
//
// For every point i computes (see the derivation above):
//     diff_src_i = A - B, where
//     A = diff_dst_i * O(i)^(-beta)
//     B = 2 * alpha * beta / n * x_i
//         * Sum:k [diff_dst_k * O(k)^(-beta) / O(k) * x_k]
// with O(i) = k + alpha / n * Sum:j [x_j^2] over the local_size window.
//
// @tparam d_type source/destination data type; accumulation is in float.
// @tparam tag    format tag this instantiation is specialized for (same
//                layout handling as in execute_forward).
template <impl::data_type_t d_type>
template <dnnl_format_tag_t tag>
status_t ref_lrn_bwd_t<d_type>::execute_backward(const exec_ctx_t &ctx) const {
    using namespace alg_kind;
    using namespace format_tag;

    status_t status = status::success;

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
    CHECK(status);

    const memory_desc_wrapper data_d(pd()->src_md());

    const dim_t C = pd()->C();
    const dim_t D = pd()->D();
    const dim_t H = pd()->H();
    const dim_t W = pd()->W();
    const auto stride_mb = data_d.blocking_desc().strides[0];
    const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels;
    static constexpr dim_t blksize = tag == nChw16c ? 16 : 8;
    const auto ndims = data_d.ndims();

    // Number of summands `n` in the normalization denominator: `size` for
    // ACROSS; size^(ndims - 2) for WITHIN (spatial window).
    auto compute_n_summands = [&](dim_t size) {
        if (across_channels) {
            return size;
        } else { // within_channel
            dim_t n_summands = 1;
            for (auto d = ndims - 2; d > 0; --d)
                n_summands *= size;
            return n_summands;
        }
    };

    const acc_data_t alpha = static_cast<acc_data_t>(pd()->desc()->lrn_alpha);
    const acc_data_t beta = static_cast<acc_data_t>(pd()->desc()->lrn_beta);
    const acc_data_t k = static_cast<acc_data_t>(pd()->desc()->lrn_k);
    const dim_t size = pd()->desc()->local_size;
    // Window extends half_size points on each side of the center point.
    const dim_t half_size = (size - 1) / 2;
    const dim_t summands = compute_n_summands(size);

    // Maps logical (mb, c, d, h, w) coordinates to the physical element
    // offset; identical layout handling to the forward pass.
    auto data_off = [&](dim_t mb, dim_t c, dim_t d, dim_t h, dim_t w) -> dim_t {
        switch (tag) {
            case nChw16c:
            case nChw8c:
                return mb * stride_mb + (c / blksize) * H * W * blksize
                        + h * W * blksize + w * blksize + c % blksize;
            case nchw: return mb * stride_mb + c * H * W + h * W + w;
            case nhwc: return mb * stride_mb + h * W * C + w * C + c;
            default:
                if (ndims >= 5) return data_d.off(mb, c, d, h, w);
                if (ndims >= 4) return data_d.off(mb, c, h, w);
                if (ndims >= 3) return data_d.off(mb, c, w);
                return data_d.off(mb, c);
        }
    };

    // Recomputes O(i) = k + alpha / n * Sum:j [x_j^2] for one point, with
    // the window clamped at the tensor borders. `ker` calls this once per
    // neighbor, so the reference backward is quadratic in the window size
    // per output point.
    // pass by value due to icc170 and icc180 problem on KNL
    auto get_omega = [=](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
        acc_data_t sum = 0;
        if (across_channels) {
            const dim_t c_st = nstl::max(oc - half_size + 0, (dim_t)0);
            const dim_t c_en = nstl::min(oc + half_size + 1, C);

            for (dim_t c = c_st; c < c_en; ++c) {
                const acc_data_t s = src[data_off(mb, c, od, oh, ow)];
                sum += s * s;
            }
        } else {
            dim_t d_st = nstl::max(od - half_size + 0, (dim_t)0);
            dim_t d_en = nstl::min(od + half_size + 1, D);
            dim_t h_st = nstl::max(oh - half_size + 0, (dim_t)0);
            dim_t h_en = nstl::min(oh + half_size + 1, H);
            dim_t w_st = nstl::max(ow - half_size + 0, (dim_t)0);
            dim_t w_en = nstl::min(ow + half_size + 1, W);
            for_(dim_t d = d_st; d < d_en; ++d)
            for_(dim_t h = h_st; h < h_en; ++h)
            for (dim_t w = w_st; w < w_en; ++w) {
                const acc_data_t s = src[data_off(mb, oc, d, h, w)];
                sum += s * s;
            }
        }
        return (acc_data_t)(k + alpha * sum / summands);
    };

    // Computes one diff_src point (mb, oc, od, oh, ow) into *d: walks the
    // (border-clamped) window accumulating B, capturing the diagonal term A
    // when the neighbor coincides with the center point, then writes A - B.
    // pass by value due to icc170 and icc180 problem on KNL
    auto ker = [=](data_t *d, dim_t mb, dim_t oc, dim_t od, dim_t oh,
                       dim_t ow) {
        acc_data_t A = 0, B = 0;
        if (across_channels) {
            const dim_t c_st = nstl::max(oc - half_size + 0, (dim_t)0);
            const dim_t c_en = nstl::min(oc + half_size + 1, C);

            for (dim_t c = c_st; c < c_en; c++) {
                const auto off = data_off(mb, c, od, oh, ow);
                const acc_data_t omega = get_omega(mb, c, od, oh, ow);
                const acc_data_t omega_in_beta
                        = fast_negative_powf(omega, beta);
                // tmp = diff_dst_k * O(k)^(-beta)
                const acc_data_t tmp
                        = omega_in_beta * (acc_data_t)diff_dst[off];
                if (c == oc) A = tmp;
                B += (src[off] * tmp / omega);
            }
        } else {
            dim_t d_st = nstl::max(od - half_size + 0, (dim_t)0);
            dim_t d_en = nstl::min(od + half_size + 1, D);
            dim_t h_st = nstl::max(oh - half_size + 0, (dim_t)0);
            dim_t h_en = nstl::min(oh + half_size + 1, H);
            dim_t w_st = nstl::max(ow - half_size + 0, (dim_t)0);
            dim_t w_en = nstl::min(ow + half_size + 1, W);
            for_(dim_t d = d_st; d < d_en; ++d)
            for_(dim_t h = h_st; h < h_en; ++h)
            for (dim_t w = w_st; w < w_en; ++w) {
                const auto off = data_off(mb, oc, d, h, w);
                const acc_data_t omega = get_omega(mb, oc, d, h, w);
                const acc_data_t omega_in_beta
                        = fast_negative_powf(omega, beta);
                const acc_data_t tmp
                        = omega_in_beta * (acc_data_t)diff_dst[off];
                if (d == od && h == oh && w == ow) A = tmp;
                B += (src[off] * tmp / omega);
            }
        }
        const auto off = data_off(mb, oc, od, oh, ow);
        // Scale the accumulated sum by 2*alpha*beta*x_i/n to finish B.
        B *= (2.0f * alpha * beta * src[off] / summands);
        *d = static_cast<data_t>(A - B);
    };

    const dim_t MB = pd()->MB();
    if (tag == nChw16c || tag == nChw8c) {
        // Blocked layouts: vectorize over the channels of one block, with a
        // tail guard for C not divisible by blksize. 4D tags, hence od = 0.
        parallel_nd(MB, utils::div_up(C, blksize), H, W,
                [&](dim_t mb, dim_t c_blk, dim_t h, dim_t w) {
                    dim_t c = c_blk * blksize;
                    const dim_t off = mb * stride_mb + c * H * W
                            + (h * W + w) * blksize;
                    PRAGMA_OMP_SIMD()
                    for (dim_t cc = 0; cc < nstl::min(blksize, C - c); ++cc)
                        ker(&diff_src[off + cc], mb, c + cc, 0, h, w);
                });
    } else if (tag == nhwc) {
        // Channels-last layout (4D tag, od = 0): offset computed inline.
        parallel_nd(MB, H, W, C, [&](dim_t mb, dim_t h, dim_t w, dim_t c) {
            const dim_t off = mb * stride_mb + h * W * C + w * C + c;
            ker(&diff_src[off], mb, c, 0, h, w);
        });
    } else {
        // Generic path: fully parallel over all logical coordinates.
        parallel_nd(MB, C, D, H, W,
                [&](dim_t mb, dim_t c, dim_t d, dim_t h, dim_t w) {
                    const dim_t off = data_off(mb, c, d, h, w);
                    ker(&diff_src[off], mb, c, d, h, w);
                });
    }
    return status::success;
}
338 | |
// Explicit instantiations of the templated kernels for every supported
// (data type x format tag) combination referenced by ref_lrn.hpp.

// f32
template status_t
ref_lrn_fwd_t<data_type::f32>::execute_forward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f32>::execute_forward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f32>::execute_forward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f32>::execute_forward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f32>::execute_forward<format_tag::any>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f32>::execute_backward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f32>::execute_backward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f32>::execute_backward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f32>::execute_backward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f32>::execute_backward<format_tag::any>(
        const exec_ctx_t &ctx) const;

// bf16
template status_t
ref_lrn_fwd_t<data_type::bf16>::execute_forward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::bf16>::execute_forward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::bf16>::execute_forward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::bf16>::execute_forward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::bf16>::execute_forward<format_tag::any>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::bf16>::execute_backward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::bf16>::execute_backward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::bf16>::execute_backward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::bf16>::execute_backward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::bf16>::execute_backward<format_tag::any>(
        const exec_ctx_t &ctx) const;

// f16
template status_t
ref_lrn_fwd_t<data_type::f16>::execute_forward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f16>::execute_forward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f16>::execute_forward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f16>::execute_forward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f16>::execute_forward<format_tag::any>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f16>::execute_backward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f16>::execute_backward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f16>::execute_backward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f16>::execute_backward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f16>::execute_backward<format_tag::any>(
        const exec_ctx_t &ctx) const;
431 | |
432 | } // namespace cpu |
433 | } // namespace impl |
434 | } // namespace dnnl |
435 | |
436 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
437 | |