1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_thread.hpp" |
21 | #include "common/math_utils.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | |
24 | #include "cpu/ref_eltwise.hpp" |
25 | #include "cpu/simple_q10n.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | |
// Computes the physical offset of logical coordinates (n, c, d, h, w) in the
// memory descriptor wrapper `f`, dropping the spatial dimensions that do not
// exist for the current rank. NOTE: relies on a local variable named `ndims`
// being in scope at every expansion site.
#define DATA_OFF(f, n, c, d, h, w) \
    (ndims == 1) \
            ? (f).off(n) \
            : ((ndims == 2) ? (f).off(n, c) \
                            : ((ndims == 3) ? (f).off(n, c, w) \
                                            : ((ndims == 4) ? (f).off( \
                                                       n, c, h, w) \
                                                            : (f).off(n, c, d, \
                                                                    h, w))))
40 | |
41 | template <data_type_t data_type> |
42 | status_t ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded( |
43 | const exec_ctx_t &ctx) const { |
44 | status_t status = status::success; |
45 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
46 | auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); |
47 | CHECK(status); |
48 | |
49 | const memory_desc_wrapper src_d(pd()->src_md()); |
50 | const blocking_desc_t &blk = src_d.blocking_desc(); |
51 | const dim_t block = blk.inner_blks[0]; |
52 | |
53 | const dim_t MB = pd()->MB(); |
54 | const dim_t C = pd()->C() / block; |
55 | const dim_t C_PADDED = src_d.padded_dims()[1] / block; |
56 | const dim_t tail = pd()->C() % block; |
57 | const dim_t SP = pd()->D() * pd()->H() * pd()->W(); |
58 | const auto alg_kind = pd()->desc()->alg_kind; |
59 | const float alpha = pd()->desc()->alpha; |
60 | const float beta = pd()->desc()->beta; |
61 | |
62 | auto ker = [=](data_t &d, data_t s) { |
63 | float res = compute_eltwise_scalar_fwd(alg_kind, s, alpha, beta); |
64 | d = cpu::saturate_and_round<data_t>(res); |
65 | }; |
66 | |
67 | parallel_nd(MB, C_PADDED, SP, [&](dim_t n, dim_t c, dim_t sp) { |
68 | auto d_off = (n * C_PADDED * SP + c * SP + sp) * block; |
69 | if (c < C) { |
70 | for (dim_t v = 0; v < block; v++) |
71 | ker(dst[d_off + v], src[d_off + v]); |
72 | } else { |
73 | for (dim_t v = 0; v < tail; v++) |
74 | ker(dst[d_off + v], src[d_off + v]); |
75 | } |
76 | }); |
77 | |
78 | return status::success; |
79 | } |
80 | |
81 | template <data_type_t data_type> |
82 | status_t ref_eltwise_fwd_t<data_type>::execute_forward_generic( |
83 | const exec_ctx_t &ctx) const { |
84 | /* fast return */ |
85 | if (pd()->has_zero_dim_memory()) return status::success; |
86 | |
87 | status_t status = status::success; |
88 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
89 | auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); |
90 | CHECK(status); |
91 | |
92 | const memory_desc_wrapper src_d(pd()->src_md()); |
93 | |
94 | const dim_t MB = pd()->MB(); |
95 | const dim_t C = pd()->C(); |
96 | const dim_t D = pd()->D(); |
97 | const dim_t H = pd()->H(); |
98 | const dim_t W = pd()->W(); |
99 | const auto alg_kind = pd()->desc()->alg_kind; |
100 | const float alpha = pd()->desc()->alpha; |
101 | const float beta = pd()->desc()->beta; |
102 | const int ndims = pd()->ndims(); |
103 | |
104 | parallel_nd( |
105 | MB, C, D, H, W, [&](dim_t n, dim_t c, dim_t d, dim_t h, dim_t w) { |
106 | auto data_p_off = DATA_OFF(src_d, n, c, d, h, w); |
107 | float res = compute_eltwise_scalar_fwd( |
108 | alg_kind, src[data_p_off], alpha, beta); |
109 | dim_t data_l_off = (((n * C + c) * D + d) * H + h) * W + w; |
110 | |
111 | ref_post_ops_t::args_t args; |
112 | args.ctx = &ctx; |
113 | args.l_offset = data_l_off; |
114 | args.dst_md = pd()->dst_md(); |
115 | ref_post_ops->execute(res, args); |
116 | |
117 | dst[data_p_off] = cpu::saturate_and_round<data_t>(res); |
118 | }); |
119 | return status::success; |
120 | } |
121 | |
122 | template <data_type_t data_type> |
123 | status_t ref_eltwise_fwd_t<data_type>::execute_forward_dense( |
124 | const exec_ctx_t &ctx) const { |
125 | status_t status = status::success; |
126 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
127 | auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); |
128 | CHECK(status); |
129 | |
130 | const memory_desc_wrapper src_d(pd()->src_md()); |
131 | |
132 | const auto nelems = src_d.nelems(true); |
133 | const auto alg_kind = pd()->desc()->alg_kind; |
134 | const float alpha = pd()->desc()->alpha; |
135 | const float beta = pd()->desc()->beta; |
136 | |
137 | src += src_d.offset0(); |
138 | dst += src_d.offset0(); |
139 | |
140 | // a fast path for relu as the most popular activation |
141 | if (alg_kind == alg_kind::eltwise_relu && alpha == 0) { |
142 | parallel_nd(nelems, [&](dim_t e) { |
143 | float res = math::relu_fwd(src[e], alpha); |
144 | dst[e] = cpu::saturate_and_round<data_t>(res); |
145 | }); |
146 | return status::success; |
147 | } |
148 | |
149 | parallel_nd(nelems, [&](dim_t e) { |
150 | float res = compute_eltwise_scalar_fwd(alg_kind, src[e], alpha, beta); |
151 | dst[e] = cpu::saturate_and_round<data_t>(res); |
152 | }); |
153 | return status::success; |
154 | } |
155 | |
156 | template <data_type_t data_type> |
157 | status_t ref_eltwise_bwd_t<data_type>::execute_backward_generic( |
158 | const exec_ctx_t &ctx) const { |
159 | /* fast return */ |
160 | if (pd()->has_zero_dim_memory()) return status::success; |
161 | |
162 | status_t status = status::success; |
163 | auto src = CTX_IN_MEM( |
164 | const data_t *, pd()->use_dst() ? DNNL_ARG_DST : DNNL_ARG_SRC); |
165 | auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
166 | auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status); |
167 | CHECK(status); |
168 | |
169 | const memory_desc_wrapper data_d(pd()->data_md()); |
170 | const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); |
171 | |
172 | const dim_t MB = pd()->MB(); |
173 | const dim_t C = pd()->C(); |
174 | const dim_t D = pd()->D(); |
175 | const dim_t H = pd()->H(); |
176 | const dim_t W = pd()->W(); |
177 | const auto alg_kind = pd()->desc()->alg_kind; |
178 | const float alpha = pd()->desc()->alpha; |
179 | const float beta = pd()->desc()->beta; |
180 | const int ndims = pd()->ndims(); |
181 | |
182 | parallel_nd( |
183 | MB, C, D, H, W, [&](dim_t n, dim_t c, dim_t d, dim_t h, dim_t w) { |
184 | auto data_off = DATA_OFF(data_d, n, c, d, h, w); |
185 | auto diff_data_off = DATA_OFF(diff_data_d, n, c, d, h, w); |
186 | data_t s = src[data_off]; |
187 | data_t dd = diff_dst[diff_data_off]; |
188 | data_t &ds = diff_src[diff_data_off]; |
189 | ds = compute_eltwise_scalar_bwd(alg_kind, dd, s, alpha, beta); |
190 | }); |
191 | return status::success; |
192 | } |
193 | |
template <data_type_t data_type>
status_t ref_eltwise_bwd_t<data_type>::execute_backward_dense(
        const exec_ctx_t &ctx) const {
    // Backward eltwise over a dense buffer. f32 data is processed in place;
    // bf16/f16 data is up-converted to f32 scratchpad buffers, processed,
    // and converted back, since the scalar backward kernel works in f32.
    status_t status = status::success;
    // Algorithms with use_dst() == true compute the gradient from dst.
    const void *src = pd()->use_dst() ? CTX_IN_MEM(const void *, DNNL_ARG_DST)
                                      : CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    const void *diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
    void *diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status);
    CHECK(status);

    const memory_desc_wrapper data_d(pd()->data_md());
    const memory_desc_wrapper diff_data_d(pd()->diff_src_md());

    const auto nelems = data_d.nelems(true); // include padded elements
    const auto alg_kind = pd()->desc()->alg_kind;
    const float alpha = pd()->desc()->alpha;
    const float beta = pd()->desc()->beta;

    if (data_type == data_type::f32) {
        // f32: operate directly on the user buffers, no conversion needed.
        const float *src_ptr = static_cast<const float *>(src);
        const float *diff_dst_ptr = static_cast<const float *>(diff_dst);
        float *diff_src_ptr = static_cast<float *>(diff_src);

        src_ptr += data_d.offset0();
        diff_dst_ptr += diff_data_d.offset0();
        diff_src_ptr += diff_data_d.offset0();

        parallel(0, [&](const int ithr, const int nthr) {
            dim_t start = 0, end = 0;
            balance211(nelems, nthr, ithr, start, end);
            if (start == end) return;

            for (dim_t i = start; i < end; i++) {
                diff_src_ptr[i] = compute_eltwise_scalar_bwd(
                        alg_kind, diff_dst_ptr[i], src_ptr[i], alpha, beta);
            }
        });
    } else if (utils::one_of(data_type, data_type::bf16, data_type::f16)) {
        const data_t *src_ptr = static_cast<const data_t *>(src);
        const data_t *diff_dst_ptr = static_cast<const data_t *>(diff_dst);
        data_t *diff_src_ptr = static_cast<data_t *>(diff_src);

        src_ptr += data_d.offset0();
        diff_dst_ptr += diff_data_d.offset0();
        diff_src_ptr += diff_data_d.offset0();

        // f32 staging buffers sized for the whole tensor; each thread only
        // touches its own [start, end) slice below.
        using namespace memory_tracking::names;
        auto scratchpad = ctx.get_scratchpad_grantor();
        auto *src_f32 = scratchpad.template get<float>(key_eltwise_src);
        auto *diff_dst_f32
                = scratchpad.template get<float>(key_eltwise_diff_dst);

        parallel(0, [&](const int ithr, const int nthr) {
            dim_t start = 0, end = 0;
            balance211(nelems, nthr, ithr, start, end);
            if (start == end) return;

            // Up-convert this thread's slice to f32.
            types::cvt_to_float(src_f32 + start, src_ptr + start, end - start);
            types::cvt_to_float(
                    diff_dst_f32 + start, diff_dst_ptr + start, end - start);

            // Deliberately reuse diff_dst_f32 as the result buffer: each
            // element is consumed before it is overwritten.
            for (dim_t i = start; i < end; i++) {
                diff_dst_f32[i] = compute_eltwise_scalar_bwd(
                        alg_kind, diff_dst_f32[i], src_f32[i], alpha, beta);
            }

            // Down-convert the results into the user's diff_src buffer.
            types::cvt_from_float(
                    diff_src_ptr + start, diff_dst_f32 + start, end - start);
        });
    } else {
        assert(!"unsupported data type" );
    }
    return status::success;
}
268 | |
// Explicit instantiations: forward supports float and integer data types;
// backward is floating-point only.
template struct ref_eltwise_fwd_t<data_type::f32>;
template struct ref_eltwise_fwd_t<data_type::bf16>;
template struct ref_eltwise_fwd_t<data_type::f16>;
template struct ref_eltwise_fwd_t<data_type::s32>;
template struct ref_eltwise_fwd_t<data_type::s8>;
template struct ref_eltwise_fwd_t<data_type::u8>;

template struct ref_eltwise_bwd_t<data_type::f32>;
template struct ref_eltwise_bwd_t<data_type::bf16>;
template struct ref_eltwise_bwd_t<data_type::f16>;
279 | |
280 | } // namespace cpu |
281 | } // namespace impl |
282 | } // namespace dnnl |
283 | |
284 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
285 | |