/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <math.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>

#include "oneapi/dnnl/dnnl.h"

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "binary/binary.hpp"
#include "eltwise/eltwise.hpp"

namespace eltwise {

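// Creates the eltwise primitive descriptor for the given problem: a forward
// (inference or training) descriptor when the requested direction is forward,
// a backward one otherwise.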
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;
    const dir_t dir = init_pd_args.dir;

    auto src_d = dnn_mem_t::init_md(
            prb->ndims, prb->dims.data(), prb->dt, prb->tag);
    auto dst_d = dnn_mem_t::init_md(
            prb->ndims, prb->dims.data(), prb->dt, tag::any);

    dnnl_alg_kind_t alg = attr_t::post_ops_t::kind2dnnl_kind(prb->alg);

    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, prb->dims.data());
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    if (dir & FLAG_FWD) {
        auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference
                                        : dnnl_forward_training;

        DNN_SAFE_STATUS(dnnl_eltwise_forward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, prop, alg, src_d, dst_d,
                prb->alpha, prb->beta, dnnl_attr));
    } else {
        auto diff_src_d = dnn_mem_t::init_md(
                prb->ndims, prb->dims.data(), prb->dt, tag::any);
        auto diff_dst_d = dnn_mem_t::init_md(
                prb->ndims, prb->dims.data(), prb->dt, tag::any);
        if (prb->use_dst()) // Need to create with proper tag
            dst_d = dnn_mem_t::init_md(
                    prb->ndims, prb->dims.data(), prb->dt, prb->tag);
        auto &data_d = prb->use_dst() ? dst_d : src_d;

        DNN_SAFE_STATUS(dnnl_eltwise_backward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, alg, diff_src_d,
                diff_dst_d, data_d, prb->alpha, prb->beta, init_pd_args.hint,
                dnnl_attr));
    }

    return dnnl_success;
}

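// Returns true when the comparison should fall back to an absolute error
// check for input `s`: at such points the algorithm's formula suffers
// catastrophic cancellation, so the relative error criterion is too strict.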
static bool check_abs_err(const prb_t *prb, const float &s, const float &trh) {
    const float approx_machine_eps = 2 * epsilon_dt(dnnl_f32);
    const float comp_err = approx_machine_eps / trh;

    switch (prb->alg) {
        case alg_t::ELU:
        case alg_t::ELU_DST:
            // catch catastrophic cancellation in (exp(s) - 1) when s < 0 and
            // s is close to zero.
            return (prb->dir & FLAG_FWD) && std::signbit(s)
                    && (fabsf(expf(s) - 1.f) <= comp_err);
        case alg_t::GELU_TANH: {
            // catch catastrophic cancellation
            // (4.f is magic scale for f32)
            const float sqrt_2_over_pi = 0.797884;
            const float fitting_const = 0.044715;
            float v = tanhf(sqrt_2_over_pi * s * (1 + fitting_const * s * s));
            float dg = sqrt_2_over_pi * (1 + 3 * fitting_const * s * s);
            if (fabsf(1.f + v) <= comp_err) return true;
            return (prb->dir & FLAG_BWD) && std::signbit(s)
                    && fabsf(1.f + s * (1.f - v) * dg) <= 4.f * comp_err;
        }
        case alg_t::GELU_ERF: {
            // Catch catastrophic cancellation
            // which occurs at large negative s.
            // Factor 2 (in bwd) is to account for the fact that error is
            // accumulated for each summand (except the 1) when they
            // are of the same order of magnitude.
            const float sqrt_2_over_2 = 0.707106769084930419921875f;
            const float two_over_sqrt_pi = 1.12837922573089599609375f;
            float v = s * sqrt_2_over_2;
            if (prb->dir & FLAG_FWD)
                return fabsf(1.f + erff(v)) <= comp_err;
            else
                return fabsf(1.f + erff(v)
                               + v * two_over_sqrt_pi * expf(-v * v))
                        <= comp_err * 2;
        }
        case alg_t::TANH:
            // catch catastrophic cancellation, which occurs when err in tanh(s)
            // is high and tanh(s) is close to 1.
            return (prb->dir & FLAG_BWD) && (1.f - tanhf(fabsf(s))) <= comp_err;
        case alg_t::TANH_DST: // sse41 can't do fma
            // catch catastrophic cancellation, which occurs when err in tanh(s)
            // is high and tanh(s) is close to 1.
            return (prb->dir & FLAG_BWD) && (1.f - s * s) <= comp_err;
        case alg_t::SRELU:
            // when `alpha * s` is negative, expf(alpha * s) -> 0 rapidly
            // which leads to log1pf(expf(alpha * s)) -> 0
            // which leads to high relative error,
            // while abs error is still low.
            // (10.f is magic scale for bf16)
            return (prb->dir & FLAG_FWD) && std::signbit(prb->alpha * s)
                    && log1pf(expf(prb->alpha * s)) <= 10.f * comp_err;
        case alg_t::MISH:
            // same situation as in SRELU
            return (prb->dir & FLAG_FWD) && std::signbit(s)
                    && s * tanhf(log1pf(expf(s))) <= 10.f * comp_err;
        case alg_t::LOGISTIC:
            // when s >= 4, (1 - logistic(s)) -> 0 rapidly, which leads to
            // high relative error of logistic(s) * (1 - logistic(s)) due to
            // catastrophic cancellation.
            return (prb->dir & FLAG_BWD) && !std::signbit(s)
                    && (1.f / (1.f + expf(s))) <= comp_err;
        case alg_t::LOGISTIC_DST:
            // when s = logistic(x) ~~ 1, it leads to high relative error of
            // s * (1 - s) due to catastrophic cancellation.
            return (prb->dir & FLAG_BWD)
                    && ((1 - s) <= comp_err || s <= comp_err);
        case alg_t::SWISH: {
            // catch cancellation happening when W(s) ~~ -1 in the (1 + W(s))
            // part of the formula on backward.
            const float alpha_s = prb->alpha * s;
            return (prb->dir & FLAG_BWD)
                    && (alpha_s * (1.f - 1.f / (1.f + expf(-alpha_s)))
                            <= comp_err);
        }
        default: return false;
    }
}

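// Returns the error threshold used by the comparison for a given data type,
// algorithm, and direction.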
float get_eltwise_threshold(dnnl_data_type_t dt, alg_t alg, bool is_fwd) {
    // Tolerate only rounding error (1 ulp) for precisions other than fp32.
    float trh = dt == dnnl_f32 ? 4e-6 : epsilon_dt(dt);
    // Tolerate bigger compute errors for complex algorithms.
    const bool alg_has_higher_tolerance = alg == alg_t::GELU_TANH
            || alg == alg_t::ELU || alg == alg_t::SWISH || alg == alg_t::TANH
            || alg == alg_t::SRELU || alg == alg_t::MISH || alg == alg_t::LOG
            || ((alg == alg_t::ELU_DST || alg == alg_t::TANH_DST) && is_fwd);
    if (dt == dnnl_f32 && alg_has_higher_tolerance) trh = 4e-5;
    return trh;
}

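// Returns the percentage of zero output values the comparison should still
// trust for a given problem. 100.f means an all-zero output is legitimate
// (e.g., LINEAR with alpha == 0 or any integral data type).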
static float get_eltwise_zero_trust_percent(const prb_t *prb) {
    float ztp = 65.f; // default for eltwise due to filling.
    switch (prb->alg) {
        case alg_t::LINEAR:
            if (prb->alpha == 0) ztp = 100.f;
            break;
        case alg_t::CLIP:
        case alg_t::CLIP_V2:
        case alg_t::CLIP_V2_DST:
            if ((prb->alpha == 0 && prb->beta == 0) || (prb->dir & FLAG_BWD))
                ztp = 100.f;
            break;
        case alg_t::POW:
            if (prb->alpha == 0 || ((prb->dir & FLAG_BWD) && prb->beta == 0))
                ztp = 100.f;
            break;
        default: break;
    }
    // Integral data types with small float values will produce mostly zeros.
    // u8 with negative alpha will produce only zeros.
    if (is_integral_dt(prb->dt)) ztp = 100.f;
    return ztp;
}

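// Fills the memory with a deterministic mixture of integer, small fractional,
// large, and corner-case values (alpha, beta, and, for LOG, +/-infinity) so
// that each algorithm's special points get exercised.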
int fill_data(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp) {
    const auto nelems = mem_fp.nelems();
    if (nelems == 0) return OK;

    /* Do fixed partitioning to have same filling for any number of threads */
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);
    const bool is_log = prb->alg == alg_t::LOG;

    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // Note 1: we use a different seed for each chunk to avoid
        // repeating patterns. We could use discard(idx_start) too, but
        // we avoid it for two reasons:
        //   a. it has complexity O(idx_start);
        //   b. igen and fgen below might require more than one sample
        //      per idx, so we cannot deterministically compute the
        //      number of states we need to discard.
        // Note 2: we also advance the state to avoid having only
        // small values as the first chunk's input. The +1 is necessary to
        // avoid generating zeros in the first chunk.
        // Note 3: we multiply by kind + 1 to have different values in
        // src/dst and diff_dst. The +1 is to avoid 0 again.
        std::minstd_rand msr((idx_start + 1) * (kind + 1));
        msr.discard(1);
        std::uniform_int_distribution<> igen(0, 10);
        // TODO: the upper bound is 0.09 because the log implementation does
        // not deliver good accuracy at points around 0.99.
        std::uniform_real_distribution<> fgen(0.f, 0.09f);

        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            const int64_t num_of_generation_variants
                    = 13 + (2 * static_cast<int64_t>(is_log));
            float value = FLT_MAX;
            switch (idx % num_of_generation_variants) {
                case 0: value = (float)igen(msr); break; // [0-10] pos
                case 1: value = -(float)igen(msr); break; // [0-10] neg
                case 2: value = fgen(msr); break; // [0.-0.1) pos
                case 3: value = -fgen(msr); break; // [0.-0.1) neg
                case 4: value = 10 * (float)igen(msr); break; // [0-100] pos
                case 5: value = -10 * (float)igen(msr); break; // [0-100] neg
                case 6: value = 10.f * fgen(msr); break; // [0.-1.) pos
                case 7: value = -10.f * fgen(msr); break; // [0.-1.) neg
                case 8:
                    value = 88.f + 10.f * fgen(msr);
                    break; // values close to logf(FLT_MAX) for exp alg testing
                case 9:
                    value = 22.f + 10.f * fgen(msr);
                    break; // values close to logf(FLT_MAX)/4.0 for bwd mish alg testing
                case 10:
                    value = 44.f + 10.f * fgen(msr);
                    break; // values close to logf(FLT_MAX)/2.0 for fwd mish alg testing
                case 11: value = prb->alpha; break; // `x = alpha` corner cases
                case 12: value = prb->beta; break; // `x = beta` corner cases
                case 13: value = INFINITY; break; // used in LOG alg only
                case 14: value = -INFINITY; break; // used in LOG alg only
            }
            value = round_to_nearest_representable(prb->dt, value);

            // Hack: -0 may lead to different sign in the answer since input
            // passes through simple reorder which converts -0 into +0.
            if (value == -0.f) value = 0.f;

            mem_fp.set_elem(idx, value);
        }
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type({prb->dt}, prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res);
}

void skip_invalid_prb(const prb_t *prb, res_t *res) {
    bool is_invalid = false;
    switch (prb->alg) {
        case alg_t::CLIP:
        case alg_t::CLIP_V2:
        case alg_t::CLIP_V2_DST: is_invalid = prb->beta < prb->alpha; break;
        case alg_t::ELU_DST:
        case alg_t::RELU_DST: is_invalid = prb->alpha < 0; break;
        case alg_t::ROUND:
            is_invalid = prb->dt != dnnl_f32 || prb->dir & FLAG_BWD;
            break;
        default: break;
    };
    if (is_invalid) {
        res->state = SKIPPED, res->reason = INVALID_CASE;
        return;
    }

    // Since source is needed for non-use-dst algorithms, it is incorrect to
    // let forward path overwrite it.
    is_invalid = (prb->dir & FLAG_BWD) && !prb->use_dst() && prb->inplace;
    if (is_invalid) {
        res->state = SKIPPED, res->reason = INVALID_CASE;
        return;
    }

    // See `skip_invalid_inplace` for details.
    if (prb->inplace) {
        skip_invalid_inplace(res, prb->dt, prb->dt, prb->tag, prb->tag);
        if (res->state == SKIPPED) return;
    }
}

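// Configures the correctness comparison: the error threshold, the expected
// zero percentage, and a driver-specific check that accepts absolute error
// where catastrophic cancellation is expected or a binary post-op is present.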
void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    const float trh
            = get_eltwise_threshold(prb->dt, prb->alg, prb->dir & FLAG_FWD);
    cmp.set_threshold(trh);

    cmp.set_zero_trust_percent(get_eltwise_zero_trust_percent(prb));

    // Since the lambda is called when the stack is unavailable, `prb` has to
    // be captured by value to avoid using dangling references.
    const auto eltwise_add_check =
            [&, prb](const compare::compare_t::driver_check_func_args_t &args) {
                // Some algorithms require an absolute error comparison for
                // inputs where catastrophic cancellation may happen.
                const auto &src = ref_args.find(DNNL_ARG_SRC);
                const auto &dst = ref_args.find(DNNL_ARG_DST);
                const auto &source
                        = ((prb->dir & FLAG_BWD) && prb->use_dst()) ? dst : src;
                const float s = source.get_elem(args.idx);
                if (check_abs_err(prb, s, args.trh))
                    return args.diff <= args.trh;
                if (prb->attr.post_ops.binary_index() != -1)
                    return args.diff <= args.trh;
                return false;
            };
    cmp.set_driver_check_function(eltwise_add_check);
}

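// Main driver entry: creates the primitive, fills the inputs, executes the
// forward pass (and the backward pass when requested), checks correctness
// against the f32 reference in correctness mode, and measures performance.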
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_fpd = query_pd(prim);

    const auto &src_md = query_md(const_fpd, DNNL_ARG_SRC);
    const auto &dst_md = query_md(const_fpd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_fpd, DNNL_ARG_SCRATCHPAD);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_fp(src_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t src_dt(src_md, test_engine);

    dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t placeholder_dst_dt;
    if (!prb->inplace) { placeholder_dst_dt = dnn_mem_t(dst_md, test_engine); }
    dnn_mem_t &dst_dt = prb->inplace ? src_dt : placeholder_dst_dt;

    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(
                 const_fpd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);

    dnn_mem_t d_dst_dt, placeholder_d_src_dt;

    SAFE(fill_data(prb, SRC, src_dt, src_fp), WARN);

    args_t args, ref_args;

    args.set(DNNL_ARG_SRC, src_dt);
    args.set(DNNL_ARG_DST, dst_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
    args.set(binary_po_args, binary_po_dt);

    SAFE(execute_and_wait(prim, args, res), WARN);

    if (prb->dir & FLAG_FWD) {
        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(binary_po_args, binary_po_fp);

            check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
        }
    }

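    // Backward pass: re-create the primitive using the forward primitive
    // descriptor as a hint, fill diff_dst, execute, and compare diff_src
    // against the reference.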
    if (prb->dir & FLAG_BWD) {
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> tmp_prim;
        SAFE(init_prim(prb->ctx_init, tmp_prim, init_pd, prb, res, FLAG_BWD,
                     const_fpd),
                WARN);
        if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
        prim.reset(tmp_prim.release());

        auto const_bpd = query_pd(prim);

        const auto &d_dst_md = query_md(const_bpd, DNNL_ARG_DIFF_DST);
        const auto &d_src_md = query_md(const_bpd, DNNL_ARG_DIFF_SRC);
        const auto &d_scratchpad_md = query_md(const_bpd, DNNL_ARG_SCRATCHPAD);

        dnn_mem_t d_dst_fp
                = dnn_mem_t(d_dst_md, dnnl_f32, tag::abx, ref_engine);
        d_dst_dt = dnn_mem_t(d_dst_md, test_engine);

        dnn_mem_t &d_src_fp = d_dst_fp; // in-place reference
        if (!prb->inplace) {
            placeholder_d_src_dt = dnn_mem_t(d_src_md, test_engine);
        }
        dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt;

        scratchpad_dt = dnn_mem_t(d_scratchpad_md, test_engine);

        SAFE(fill_data(prb, DST, d_dst_dt, d_dst_fp), WARN);

        args.clear();
        args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        if (prb->use_dst()) {
            args.set(DNNL_ARG_DST, dst_dt);
        } else {
            args.set(DNNL_ARG_SRC, src_dt);
        }

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp);

            check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res);
        }
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace eltwise