1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18
19#include "common/c_types_map.hpp"
20#include "common/dnnl_thread.hpp"
21#include "common/math_utils.hpp"
22#include "common/type_helpers.hpp"
23
24#include "cpu/ref_eltwise.hpp"
25#include "cpu/simple_q10n.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30
// Maps logical coordinates (n, c, d, h, w) to a physical offset via the
// memory descriptor wrapper `f`, dropping the spatial dims that do not
// exist for the given rank. NOTE: relies on a local `ndims` variable being
// in scope at every expansion site.
#define DATA_OFF(f, n, c, d, h, w) \
    (ndims == 1) \
            ? (f).off(n) \
            : ((ndims == 2) ? (f).off(n, c) \
                            : ((ndims == 3) ? (f).off(n, c, w) \
                                            : ((ndims == 4) ? (f).off( \
                                                       n, c, h, w) \
                                                            : (f).off(n, c, d, \
                                                                    h, w))))
40
41template <data_type_t data_type>
42status_t ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded(
43 const exec_ctx_t &ctx) const {
44 status_t status = status::success;
45 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
46 auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
47 CHECK(status);
48
49 const memory_desc_wrapper src_d(pd()->src_md());
50 const blocking_desc_t &blk = src_d.blocking_desc();
51 const dim_t block = blk.inner_blks[0];
52
53 const dim_t MB = pd()->MB();
54 const dim_t C = pd()->C() / block;
55 const dim_t C_PADDED = src_d.padded_dims()[1] / block;
56 const dim_t tail = pd()->C() % block;
57 const dim_t SP = pd()->D() * pd()->H() * pd()->W();
58 const auto alg_kind = pd()->desc()->alg_kind;
59 const float alpha = pd()->desc()->alpha;
60 const float beta = pd()->desc()->beta;
61
62 auto ker = [=](data_t &d, data_t s) {
63 float res = compute_eltwise_scalar_fwd(alg_kind, s, alpha, beta);
64 d = cpu::saturate_and_round<data_t>(res);
65 };
66
67 parallel_nd(MB, C_PADDED, SP, [&](dim_t n, dim_t c, dim_t sp) {
68 auto d_off = (n * C_PADDED * SP + c * SP + sp) * block;
69 if (c < C) {
70 for (dim_t v = 0; v < block; v++)
71 ker(dst[d_off + v], src[d_off + v]);
72 } else {
73 for (dim_t v = 0; v < tail; v++)
74 ker(dst[d_off + v], src[d_off + v]);
75 }
76 });
77
78 return status::success;
79}
80
81template <data_type_t data_type>
82status_t ref_eltwise_fwd_t<data_type>::execute_forward_generic(
83 const exec_ctx_t &ctx) const {
84 /* fast return */
85 if (pd()->has_zero_dim_memory()) return status::success;
86
87 status_t status = status::success;
88 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
89 auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
90 CHECK(status);
91
92 const memory_desc_wrapper src_d(pd()->src_md());
93
94 const dim_t MB = pd()->MB();
95 const dim_t C = pd()->C();
96 const dim_t D = pd()->D();
97 const dim_t H = pd()->H();
98 const dim_t W = pd()->W();
99 const auto alg_kind = pd()->desc()->alg_kind;
100 const float alpha = pd()->desc()->alpha;
101 const float beta = pd()->desc()->beta;
102 const int ndims = pd()->ndims();
103
104 parallel_nd(
105 MB, C, D, H, W, [&](dim_t n, dim_t c, dim_t d, dim_t h, dim_t w) {
106 auto data_p_off = DATA_OFF(src_d, n, c, d, h, w);
107 float res = compute_eltwise_scalar_fwd(
108 alg_kind, src[data_p_off], alpha, beta);
109 dim_t data_l_off = (((n * C + c) * D + d) * H + h) * W + w;
110
111 ref_post_ops_t::args_t args;
112 args.ctx = &ctx;
113 args.l_offset = data_l_off;
114 args.dst_md = pd()->dst_md();
115 ref_post_ops->execute(res, args);
116
117 dst[data_p_off] = cpu::saturate_and_round<data_t>(res);
118 });
119 return status::success;
120}
121
122template <data_type_t data_type>
123status_t ref_eltwise_fwd_t<data_type>::execute_forward_dense(
124 const exec_ctx_t &ctx) const {
125 status_t status = status::success;
126 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
127 auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
128 CHECK(status);
129
130 const memory_desc_wrapper src_d(pd()->src_md());
131
132 const auto nelems = src_d.nelems(true);
133 const auto alg_kind = pd()->desc()->alg_kind;
134 const float alpha = pd()->desc()->alpha;
135 const float beta = pd()->desc()->beta;
136
137 src += src_d.offset0();
138 dst += src_d.offset0();
139
140 // a fast path for relu as the most popular activation
141 if (alg_kind == alg_kind::eltwise_relu && alpha == 0) {
142 parallel_nd(nelems, [&](dim_t e) {
143 float res = math::relu_fwd(src[e], alpha);
144 dst[e] = cpu::saturate_and_round<data_t>(res);
145 });
146 return status::success;
147 }
148
149 parallel_nd(nelems, [&](dim_t e) {
150 float res = compute_eltwise_scalar_fwd(alg_kind, src[e], alpha, beta);
151 dst[e] = cpu::saturate_and_round<data_t>(res);
152 });
153 return status::success;
154}
155
156template <data_type_t data_type>
157status_t ref_eltwise_bwd_t<data_type>::execute_backward_generic(
158 const exec_ctx_t &ctx) const {
159 /* fast return */
160 if (pd()->has_zero_dim_memory()) return status::success;
161
162 status_t status = status::success;
163 auto src = CTX_IN_MEM(
164 const data_t *, pd()->use_dst() ? DNNL_ARG_DST : DNNL_ARG_SRC);
165 auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
166 auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
167 CHECK(status);
168
169 const memory_desc_wrapper data_d(pd()->data_md());
170 const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
171
172 const dim_t MB = pd()->MB();
173 const dim_t C = pd()->C();
174 const dim_t D = pd()->D();
175 const dim_t H = pd()->H();
176 const dim_t W = pd()->W();
177 const auto alg_kind = pd()->desc()->alg_kind;
178 const float alpha = pd()->desc()->alpha;
179 const float beta = pd()->desc()->beta;
180 const int ndims = pd()->ndims();
181
182 parallel_nd(
183 MB, C, D, H, W, [&](dim_t n, dim_t c, dim_t d, dim_t h, dim_t w) {
184 auto data_off = DATA_OFF(data_d, n, c, d, h, w);
185 auto diff_data_off = DATA_OFF(diff_data_d, n, c, d, h, w);
186 data_t s = src[data_off];
187 data_t dd = diff_dst[diff_data_off];
188 data_t &ds = diff_src[diff_data_off];
189 ds = compute_eltwise_scalar_bwd(alg_kind, dd, s, alpha, beta);
190 });
191 return status::success;
192}
193
template <data_type_t data_type>
status_t ref_eltwise_bwd_t<data_type>::execute_backward_dense(
        const exec_ctx_t &ctx) const {
    // Backward path for dense tensors. f32 is processed directly; bf16/f16
    // are up-converted to f32 in scratchpad buffers, processed, and
    // down-converted back. The source is dst instead of src when the
    // algorithm's derivative is expressed via the forward output
    // (pd()->use_dst()).
    status_t status = status::success;
    const void *src = pd()->use_dst() ? CTX_IN_MEM(const void *, DNNL_ARG_DST)
                                      : CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    const void *diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
    void *diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status);
    CHECK(status);

    const memory_desc_wrapper data_d(pd()->data_md());
    const memory_desc_wrapper diff_data_d(pd()->diff_src_md());

    // Includes padded elements so the whole physical buffer is covered.
    const auto nelems = data_d.nelems(true);
    const auto alg_kind = pd()->desc()->alg_kind;
    const float alpha = pd()->desc()->alpha;
    const float beta = pd()->desc()->beta;

    if (data_type == data_type::f32) {
        // Native-precision path: compute element-wise in place.
        const float *src_ptr = static_cast<const float *>(src);
        const float *diff_dst_ptr = static_cast<const float *>(diff_dst);
        float *diff_src_ptr = static_cast<float *>(diff_src);

        src_ptr += data_d.offset0();
        diff_dst_ptr += diff_data_d.offset0();
        diff_src_ptr += diff_data_d.offset0();

        parallel(0, [&](const int ithr, const int nthr) {
            dim_t start = 0, end = 0;
            balance211(nelems, nthr, ithr, start, end);
            if (start == end) return;

            for (dim_t i = start; i < end; i++) {
                diff_src_ptr[i] = compute_eltwise_scalar_bwd(
                        alg_kind, diff_dst_ptr[i], src_ptr[i], alpha, beta);
            }
        });
    } else if (utils::one_of(data_type, data_type::bf16, data_type::f16)) {
        // Reduced-precision path: stage through f32 scratchpad buffers.
        const data_t *src_ptr = static_cast<const data_t *>(src);
        const data_t *diff_dst_ptr = static_cast<const data_t *>(diff_dst);
        data_t *diff_src_ptr = static_cast<data_t *>(diff_src);

        src_ptr += data_d.offset0();
        diff_dst_ptr += diff_data_d.offset0();
        diff_src_ptr += diff_data_d.offset0();

        using namespace memory_tracking::names;
        auto scratchpad = ctx.get_scratchpad_grantor();
        auto *src_f32 = scratchpad.template get<float>(key_eltwise_src);
        auto *diff_dst_f32
                = scratchpad.template get<float>(key_eltwise_diff_dst);

        parallel(0, [&](const int ithr, const int nthr) {
            // Each thread owns a disjoint [start, end) slice of the
            // scratchpad, so no synchronization is needed.
            dim_t start = 0, end = 0;
            balance211(nelems, nthr, ithr, start, end);
            if (start == end) return;

            types::cvt_to_float(src_f32 + start, src_ptr + start, end - start);
            types::cvt_to_float(
                    diff_dst_f32 + start, diff_dst_ptr + start, end - start);

            // diff_dst_f32 is reused as the output accumulator to avoid a
            // third scratchpad buffer.
            for (dim_t i = start; i < end; i++) {
                diff_dst_f32[i] = compute_eltwise_scalar_bwd(
                        alg_kind, diff_dst_f32[i], src_f32[i], alpha, beta);
            }

            types::cvt_from_float(
                    diff_src_ptr + start, diff_dst_f32 + start, end - start);
        });
    } else {
        assert(!"unsupported data type");
    }
    return status::success;
}
268
// Explicit instantiations: forward supports integer and floating types.
template struct ref_eltwise_fwd_t<data_type::f32>;
template struct ref_eltwise_fwd_t<data_type::bf16>;
template struct ref_eltwise_fwd_t<data_type::f16>;
template struct ref_eltwise_fwd_t<data_type::s32>;
template struct ref_eltwise_fwd_t<data_type::s8>;
template struct ref_eltwise_fwd_t<data_type::u8>;

// Backward is floating-point only (gradients are not defined for int types).
template struct ref_eltwise_bwd_t<data_type::f32>;
template struct ref_eltwise_bwd_t<data_type::bf16>;
template struct ref_eltwise_bwd_t<data_type::f16>;
279
280} // namespace cpu
281} // namespace impl
282} // namespace dnnl
283
284// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
285