1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <math.h> |
18 | #include <random> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "utils/parallel.hpp" |
25 | |
26 | #include "dnnl_common.hpp" |
27 | #include "dnnl_memory.hpp" |
28 | |
29 | #include "binary/binary.hpp" |
30 | #include "eltwise/eltwise.hpp" |
31 | |
32 | namespace eltwise { |
33 | |
34 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
35 | const prb_t *prb = init_pd_args.prb; |
36 | const dir_t dir = init_pd_args.dir; |
37 | |
38 | auto src_d = dnn_mem_t::init_md( |
39 | prb->ndims, prb->dims.data(), prb->dt, prb->tag); |
40 | auto dst_d = dnn_mem_t::init_md( |
41 | prb->ndims, prb->dims.data(), prb->dt, tag::any); |
42 | |
43 | dnnl_alg_kind_t alg = attr_t::post_ops_t::kind2dnnl_kind(prb->alg); |
44 | |
45 | attr_args_t attr_args; |
46 | attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, prb->dims.data()); |
47 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
48 | create_dnnl_attr(prb->attr, attr_args)); |
49 | |
50 | if (dir & FLAG_FWD) { |
51 | auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference |
52 | : dnnl_forward_training; |
53 | |
54 | DNN_SAFE_STATUS(dnnl_eltwise_forward_primitive_desc_create( |
55 | &init_pd_args.pd, init_pd_args.engine, prop, alg, src_d, dst_d, |
56 | prb->alpha, prb->beta, dnnl_attr)); |
57 | } else { |
58 | auto diff_src_d = dnn_mem_t::init_md( |
59 | prb->ndims, prb->dims.data(), prb->dt, tag::any); |
60 | auto diff_dst_d = dnn_mem_t::init_md( |
61 | prb->ndims, prb->dims.data(), prb->dt, tag::any); |
62 | if (prb->use_dst()) // Need to create with proper tag |
63 | dst_d = dnn_mem_t::init_md( |
64 | prb->ndims, prb->dims.data(), prb->dt, prb->tag); |
65 | auto &data_d = prb->use_dst() ? dst_d : src_d; |
66 | |
67 | DNN_SAFE_STATUS(dnnl_eltwise_backward_primitive_desc_create( |
68 | &init_pd_args.pd, init_pd_args.engine, alg, diff_src_d, |
69 | diff_dst_d, data_d, prb->alpha, prb->beta, init_pd_args.hint, |
70 | dnnl_attr)); |
71 | } |
72 | |
73 | return dnnl_success; |
74 | } |
75 | |
76 | static bool check_abs_err(const prb_t *prb, const float &s, const float &trh) { |
77 | const float approx_machine_eps = 2 * epsilon_dt(dnnl_f32); |
78 | const float comp_err = approx_machine_eps / trh; |
79 | |
80 | switch (prb->alg) { |
81 | case alg_t::ELU: |
82 | case alg_t::ELU_DST: |
83 | // catch catastrophic cancellation when (exp(s) - 1), s < 0 and |
84 | // s is close to zero. |
85 | return (prb->dir & FLAG_FWD) && std::signbit(s) |
86 | && (fabsf(expf(s) - 1.f) <= comp_err); |
87 | case alg_t::GELU_TANH: { |
88 | // catch catastrophic cancellation |
89 | // (4.f is magic scale for f32) |
90 | const float sqrt_2_over_pi = 0.797884; |
91 | const float fitting_const = 0.044715; |
92 | float v = tanhf(sqrt_2_over_pi * s * (1 + fitting_const * s * s)); |
93 | float dg = sqrt_2_over_pi * (1 + 3 * fitting_const * s * s); |
94 | if (fabsf(1.f + v) <= comp_err) return true; |
95 | return (prb->dir & FLAG_BWD) && std::signbit(s) |
96 | && fabsf(1.f + s * (1.f - v) * dg) <= 4.f * comp_err; |
97 | } |
98 | case alg_t::GELU_ERF: { |
99 | // Catch catastrophic cancellation |
100 | // which occurs at large negative s. |
101 | // Factor 2 (in bwd) is to account for the fact that error is |
102 | // accumulated for each summand (except the 1) when they |
103 | // are of the same order of magnitude. |
104 | const float sqrt_2_over_2 = 0.707106769084930419921875f; |
105 | const float two_over_sqrt_pi = 1.12837922573089599609375f; |
106 | float v = s * sqrt_2_over_2; |
107 | if (prb->dir & FLAG_FWD) |
108 | return fabsf(1.f + erff(v)) <= comp_err; |
109 | else |
110 | return fabsf(1.f + erff(v) |
111 | + v * two_over_sqrt_pi * expf(-v * v)) |
112 | <= comp_err * 2; |
113 | } |
114 | case alg_t::TANH: |
115 | // catch catastrophic cancellation, which occurs when err in tanh(s) |
116 | // is high and tanh(s) is close to 1. |
117 | return (prb->dir & FLAG_BWD) && (1.f - tanhf(fabsf(s))) <= comp_err; |
118 | case alg_t::TANH_DST: // sse41 can't do fma |
119 | // catch catastrophic cancellation, which occurs when err in tanh(s) |
120 | // is high and tanh(s) is close to 1. |
121 | return (prb->dir & FLAG_BWD) && (1.f - s * s) <= comp_err; |
122 | case alg_t::SRELU: |
123 | // when `alpha * s` is negative, expf(alpha * s) -> 0 rapidly |
124 | // which leads to log1pf(expf(alpha * s)) -> 0 |
125 | // which leads to high relative error, |
126 | // while abs error is still low. |
127 | // (10.f is magic scale for bf16) |
128 | return (prb->dir & FLAG_FWD) && std::signbit(prb->alpha * s) |
129 | && log1pf(expf(prb->alpha * s)) <= 10.f * comp_err; |
130 | case alg_t::MISH: |
131 | // same situation like in SRELU |
132 | return (prb->dir & FLAG_FWD) && std::signbit(s) |
133 | && s * tanh(log1pf(expf(s))) <= 10.f * comp_err; |
134 | case alg_t::LOGISTIC: |
135 | // when s >= 4, logistic(s) -> 0 rapidly, which leads to high |
136 | // relative error of logistic(s) * (1 - logistic(s)) due to |
137 | // catastrohic cancellation. |
138 | return (prb->dir & FLAG_BWD) && !std::signbit(s) |
139 | && (1.f / (1.f + expf(s))) <= comp_err; |
140 | case alg_t::LOGISTIC_DST: |
141 | // when s = logistic(x) ~~ 1, it leads to high relative error of |
142 | // s * (1 - s) due to catastrohic cancellation. |
143 | return (prb->dir & FLAG_BWD) |
144 | && ((1 - s) <= comp_err || s <= comp_err); |
145 | case alg_t::SWISH: { |
146 | // catch cancellation happening when W(s) ~~ -1 in (1 + W(s)) |
147 | // formula part on backward. |
148 | const float alpha_s = prb->alpha * s; |
149 | return (prb->dir & FLAG_BWD) |
150 | && (alpha_s * (1.f - 1.f / (1.f + expf(-alpha_s))) |
151 | <= comp_err); |
152 | } |
153 | default: return false; |
154 | } |
155 | } |
156 | |
157 | float get_eltwise_threshold(dnnl_data_type_t dt, alg_t alg, bool is_fwd) { |
158 | // Tolerate only rounding error (1 ulp) for other than fp32 precisions. |
159 | float trh = dt == dnnl_f32 ? 4e-6 : epsilon_dt(dt); |
160 | // Tolerate bigger compute errors for complex algorithms. |
161 | const bool alg_has_higher_tolerance = alg == alg_t::GELU_TANH |
162 | || alg == alg_t::ELU || alg == alg_t::SWISH || alg == alg_t::TANH |
163 | || alg == alg_t::SRELU || alg == alg_t::MISH || alg == alg_t::LOG |
164 | || ((alg == alg_t::ELU_DST || alg == alg_t::TANH_DST) && is_fwd); |
165 | if (dt == dnnl_f32 && alg_has_higher_tolerance) trh = 4e-5; |
166 | return trh; |
167 | } |
168 | |
169 | static float get_eltwise_zero_trust_percent(const prb_t *prb) { |
170 | float ztp = 65.f; // default for eltwise due to filling. |
171 | switch (prb->alg) { |
172 | case alg_t::LINEAR: |
173 | if (prb->alpha == 0) ztp = 100.f; |
174 | break; |
175 | case alg_t::CLIP: |
176 | case alg_t::CLIP_V2: |
177 | case alg_t::CLIP_V2_DST: |
178 | if ((prb->alpha == 0 && prb->beta == 0) || (prb->dir & FLAG_BWD)) |
179 | ztp = 100.f; |
180 | break; |
181 | case alg_t::POW: |
182 | if (prb->alpha == 0 || ((prb->dir & FLAG_BWD) && prb->beta == 0)) |
183 | ztp = 100.f; |
184 | break; |
185 | default: break; |
186 | } |
187 | // Integral data types with small float values will produce most zeros. |
188 | // u8 with negative alpha will produce only zeros. |
189 | if (is_integral_dt(prb->dt)) ztp = 100.f; |
190 | return ztp; |
191 | } |
192 | |
193 | int fill_data(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt, |
194 | dnn_mem_t &mem_fp) { |
195 | const auto nelems = mem_fp.nelems(); |
196 | if (nelems == 0) return OK; |
197 | |
198 | /* Do fixed partitioning to have same filling for any number of threads */ |
199 | const int64_t n_chunks = 16; |
200 | const int64_t chunk_size = div_up(nelems, n_chunks); |
201 | const bool is_log = prb->alg == alg_t::LOG; |
202 | |
203 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) { |
204 | int64_t idx_start = idx_chunk * chunk_size; |
205 | int64_t idx_end = MIN2(idx_start + chunk_size, nelems); |
206 | // Note 1: we use a different seed for each chunk to avoid |
207 | // repeating patterns. We could use discard(idx_start) too but |
208 | // we avoid it for two reasons: |
209 | // a. it has a complexity in O(idx_start). |
210 | // b. igen and fgen below might require more than 1 sample |
211 | // per idx, so the we cannot deterministically compute the |
212 | // number of states we need to discard |
213 | // Note 2: We also advance the state to avoid having only |
214 | // small values as first chunk input. The +1 is necessary to |
215 | // avoid generating zeros in first chunk. |
216 | // Note 3: we multiply by kind + 1 to have different values in |
217 | // src/dst and diff_dst. The +1 is to avoid 0 again. |
218 | std::minstd_rand msr((idx_start + 1) * (kind + 1)); |
219 | msr.discard(1); |
220 | std::uniform_int_distribution<> igen(0, 10); |
221 | // TODO: 0.09 due to log impl doesn't give good accuracy in 0.99 points |
222 | std::uniform_real_distribution<> fgen(0.f, 0.09f); |
223 | |
224 | for (int64_t idx = idx_start; idx < idx_end; ++idx) { |
225 | const int64_t num_of_generation_variants |
226 | = 13 + (2 * static_cast<int64_t>(is_log)); |
227 | float value = FLT_MAX; |
228 | switch (idx % num_of_generation_variants) { |
229 | case 0: value = (float)igen(msr); break; // [0-10] pos |
230 | case 1: value = -(float)igen(msr); break; // [0-10] neg |
231 | case 2: value = fgen(msr); break; // [0.-0.1) pos |
232 | case 3: value = -fgen(msr); break; // [0.-0.1) neg |
233 | case 4: value = 10 * (float)igen(msr); break; // [0-100] pos |
234 | case 5: value = -10 * (float)igen(msr); break; // [0-100] neg |
235 | case 6: value = 10.f * fgen(msr); break; // [0.-1.) pos |
236 | case 7: value = -10.f * fgen(msr); break; // [0.-1.) neg |
237 | case 8: |
238 | value = 88.f + 10.f * fgen(msr); |
239 | break; // values close to logf(FLT_MAX) for exp alg testing |
240 | case 9: |
241 | value = 22.f + 10.f * fgen(msr); |
242 | break; // values close to logf(FLT_MAX)/4.0 for bwd mish alg testing |
243 | case 10: |
244 | value = 44.f + 10.f * fgen(msr); |
245 | break; // values close to logf(FLT_MAX)/2.0 for fwd mish alg testing |
246 | case 11: value = prb->alpha; break; // `x = alpha` corner cases |
247 | case 12: value = prb->beta; break; // `x = beta` corner cases |
248 | case 13: value = INFINITY; break; // used in LOG alg only |
249 | case 14: value = -INFINITY; break; // used in LOG alg only |
250 | } |
251 | value = round_to_nearest_representable(prb->dt, value); |
252 | |
253 | // Hack: -0 may lead to different sign in the answer since input |
254 | // passes through simple reorder which converts -0 into +0. |
255 | if (value == -0.f) value = 0.f; |
256 | |
257 | mem_fp.set_elem(idx, value); |
258 | } |
259 | }); |
260 | |
261 | SAFE(mem_dt.reorder(mem_fp), WARN); |
262 | |
263 | return OK; |
264 | } |
265 | |
266 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
267 | skip_unimplemented_data_type({prb->dt}, prb->dir, res); |
268 | skip_unimplemented_sum_po(prb->attr, res); |
269 | } |
270 | |
271 | void skip_invalid_prb(const prb_t *prb, res_t *res) { |
272 | bool is_invalid = false; |
273 | switch (prb->alg) { |
274 | case alg_t::CLIP: |
275 | case alg_t::CLIP_V2: |
276 | case alg_t::CLIP_V2_DST: is_invalid = prb->beta < prb->alpha; break; |
277 | case alg_t::ELU_DST: |
278 | case alg_t::RELU_DST: is_invalid = prb->alpha < 0; break; |
279 | case alg_t::ROUND: |
280 | is_invalid = prb->dt != dnnl_f32 || prb->dir & FLAG_BWD; |
281 | break; |
282 | default: break; |
283 | }; |
284 | if (is_invalid) { |
285 | res->state = SKIPPED, res->reason = INVALID_CASE; |
286 | return; |
287 | } |
288 | |
289 | // Since source is needed for non-use-dst algorithms, it is incorrect to |
290 | // let forward path overwrite it. |
291 | is_invalid = (prb->dir & FLAG_BWD) && !prb->use_dst() && prb->inplace; |
292 | if (is_invalid) { |
293 | res->state = SKIPPED, res->reason = INVALID_CASE; |
294 | return; |
295 | } |
296 | |
297 | // See `skip_invalid_inplace` for details. |
298 | if (prb->inplace) { |
299 | skip_invalid_inplace(res, prb->dt, prb->dt, prb->tag, prb->tag); |
300 | if (res->state == SKIPPED) return; |
301 | } |
302 | } |
303 | |
304 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
305 | const args_t &ref_args) { |
306 | const float trh |
307 | = get_eltwise_threshold(prb->dt, prb->alg, prb->dir & FLAG_FWD); |
308 | cmp.set_threshold(trh); |
309 | |
310 | cmp.set_zero_trust_percent(get_eltwise_zero_trust_percent(prb)); |
311 | |
312 | // Since lambda is called when stack is unavailable, need to capture `prb` |
313 | // by value to avoid using dangling references. |
314 | const auto eltwise_add_check = |
315 | [&, prb](const compare::compare_t::driver_check_func_args_t &args) { |
316 | // Some algorithms require absolute value comparison for inputs |
317 | // where catastrophic cancellation may happen. |
318 | const auto &src = ref_args.find(DNNL_ARG_SRC); |
319 | const auto &dst = ref_args.find(DNNL_ARG_DST); |
320 | const auto &source |
321 | = ((prb->dir & FLAG_BWD) && prb->use_dst()) ? dst : src; |
322 | const float s = source.get_elem(args.idx); |
323 | if (check_abs_err(prb, s, args.trh)) |
324 | return args.diff <= args.trh; |
325 | if (prb->attr.post_ops.binary_index() != -1) |
326 | return args.diff <= args.trh; |
327 | return false; |
328 | }; |
329 | cmp.set_driver_check_function(eltwise_add_check); |
330 | } |
331 | |
332 | int doit(const prb_t *prb, res_t *res) { |
333 | if (bench_mode == LIST) return res->state = LISTED, OK; |
334 | |
335 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim; |
336 | SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN); |
337 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
338 | |
339 | auto const_fpd = query_pd(prim); |
340 | |
341 | const auto &src_md = query_md(const_fpd, DNNL_ARG_SRC); |
342 | const auto &dst_md = query_md(const_fpd, DNNL_ARG_DST); |
343 | const auto &scratchpad_md = query_md(const_fpd, DNNL_ARG_SCRATCHPAD); |
344 | |
345 | const auto &test_engine = get_test_engine(); |
346 | const auto &ref_engine = get_cpu_engine(); |
347 | |
348 | dnn_mem_t src_fp(src_md, dnnl_f32, tag::abx, ref_engine); |
349 | dnn_mem_t src_dt(src_md, test_engine); |
350 | |
351 | dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, ref_engine); |
352 | dnn_mem_t placeholder_dst_dt; |
353 | if (!prb->inplace) { placeholder_dst_dt = dnn_mem_t(dst_md, test_engine); } |
354 | dnn_mem_t &dst_dt = prb->inplace ? src_dt : placeholder_dst_dt; |
355 | |
356 | dnn_mem_t scratchpad_dt(scratchpad_md, test_engine); |
357 | std::vector<dnn_mem_t> binary_po_fp, binary_po_dt; |
358 | std::vector<int> binary_po_args; |
359 | SAFE(binary::setup_binary_po( |
360 | const_fpd, binary_po_args, binary_po_dt, binary_po_fp), |
361 | WARN); |
362 | |
363 | dnn_mem_t d_dst_dt, placeholder_d_src_dt; |
364 | |
365 | SAFE(fill_data(prb, SRC, src_dt, src_fp), WARN); |
366 | |
367 | args_t args, ref_args; |
368 | |
369 | args.set(DNNL_ARG_SRC, src_dt); |
370 | args.set(DNNL_ARG_DST, dst_dt); |
371 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
372 | args.set(binary_po_args, binary_po_dt); |
373 | |
374 | SAFE(execute_and_wait(prim, args, res), WARN); |
375 | |
376 | if (prb->dir & FLAG_FWD) { |
377 | if (is_bench_mode(CORR)) { |
378 | ref_args.set(DNNL_ARG_SRC, src_fp); |
379 | ref_args.set(DNNL_ARG_DST, dst_fp); |
380 | ref_args.set(binary_po_args, binary_po_fp); |
381 | |
382 | check_correctness(prb, {DST}, args, ref_args, setup_cmp, res); |
383 | } |
384 | } |
385 | |
386 | if (prb->dir & FLAG_BWD) { |
387 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> tmp_prim; |
388 | SAFE(init_prim(prb->ctx_init, tmp_prim, init_pd, prb, res, FLAG_BWD, |
389 | const_fpd), |
390 | WARN); |
391 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
392 | prim.reset(tmp_prim.release()); |
393 | |
394 | auto const_bpd = query_pd(prim); |
395 | |
396 | const auto &d_dst_md = query_md(const_bpd, DNNL_ARG_DIFF_DST); |
397 | const auto &d_src_md = query_md(const_bpd, DNNL_ARG_DIFF_SRC); |
398 | const auto &d_scratchpad_md = query_md(const_bpd, DNNL_ARG_SCRATCHPAD); |
399 | |
400 | dnn_mem_t d_dst_fp |
401 | = dnn_mem_t(d_dst_md, dnnl_f32, tag::abx, ref_engine); |
402 | d_dst_dt = dnn_mem_t(d_dst_md, test_engine); |
403 | |
404 | dnn_mem_t &d_src_fp = d_dst_fp; // in-place reference |
405 | if (!prb->inplace) { |
406 | placeholder_d_src_dt = dnn_mem_t(d_src_md, test_engine); |
407 | } |
408 | dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt; |
409 | |
410 | scratchpad_dt = dnn_mem_t(d_scratchpad_md, test_engine); |
411 | |
412 | SAFE(fill_data(prb, DST, d_dst_dt, d_dst_fp), WARN); |
413 | |
414 | args.clear(); |
415 | args.set(DNNL_ARG_DIFF_DST, d_dst_dt); |
416 | args.set(DNNL_ARG_DIFF_SRC, d_src_dt); |
417 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
418 | if (prb->use_dst()) { |
419 | args.set(DNNL_ARG_DST, dst_dt); |
420 | } else { |
421 | args.set(DNNL_ARG_SRC, src_dt); |
422 | } |
423 | |
424 | SAFE(execute_and_wait(prim, args, res), WARN); |
425 | |
426 | if (is_bench_mode(CORR)) { |
427 | ref_args.set(DNNL_ARG_SRC, src_fp); |
428 | ref_args.set(DNNL_ARG_DST, dst_fp); |
429 | ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp); |
430 | ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp); |
431 | |
432 | check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res); |
433 | } |
434 | } |
435 | |
436 | return measure_perf(prb->ctx_exe, res, prim, args); |
437 | } |
438 | |
439 | } // namespace eltwise |
440 | |