/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>

#include <stdio.h>
#include <stdlib.h>

#include "oneapi/dnnl/dnnl.h"

#include "utils/parallel.hpp"

#include "dnn_types.hpp"
#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "binary/binary.hpp"
#include "eltwise/eltwise.hpp"

namespace binary {

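// Fills both the reference (f32) and the tested (target data type) memory
// objects with a small deterministic set of values, so that they are exactly
// representable in low-precision data types and comparisons stay stable.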
// TODO: Consider filling with powers of 2 for division to avoid rounding
// errors.
int fill_mem(int input_idx, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
        bool only_positive_values = false, bool only_integer_values = false) {
    const auto nelems = mem_fp.nelems();
    if (nelems == 0) return OK;

    const auto dt = mem_dt.dt();
    const int range = 16;
    const int f_min = dt == dnnl_u8 ? 0 : -range / 2;

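    // The generator below walks a fixed pattern over [0, range] based on the
    // element and input indices, which keeps runs reproducible.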
    benchdnn_parallel_nd(nelems, [&](int64_t i) {
        const int64_t gen = (12 * i + 5 * input_idx + 16) % (range + 1);
        // A non-integer scale makes values fractional unless integer values
        // were explicitly requested.
        const float scale = only_integer_values ? 1.f : 1.25f;
        float value = (f_min + gen) * scale;
        if (only_positive_values) value = fabs(value);
        // Remove zeros from src1 to avoid division by zero.
        if (input_idx == 1 && value == 0.0f) value = 1.0f;
        mem_fp.set_elem(i, round_to_nearest_representable(dt, value));
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

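// Discovers binary post-ops attached to the primitive descriptor, creates a
// pair of memories (f32 reference and target data type) for each one's SRC_1
// argument, fills them, and registers the corresponding execution arguments.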
int setup_binary_po(const_dnnl_primitive_desc_t pd, std::vector<int> &args,
        std::vector<dnn_mem_t> &mem_dt, std::vector<dnn_mem_t> &mem_fp,
        bool only_positive_values, bool only_integer_values) {
    // TODO: run-time dimensions are currently not supported in binary
    // post-ops. Two ways to add support: 1) add query support to the library
    // and extract the expected md from the pd; 2) pass a vector of
    // pre-defined `po_md`s (with no run-time values) and create memories from
    // them in case the library lacks a query mechanism.
    auto const_attr_po = query_post_ops(pd);
    auto po_len = dnnl_post_ops_len(const_attr_po);
    for (int idx = 0; idx < po_len; ++idx) {
        auto kind = dnnl_post_ops_get_kind(const_attr_po, idx);
        if (kind != dnnl_binary) continue;

        int po_idx = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1;
        const auto &po_md = query_md(pd, po_idx);

        // The following call cannot be executed if `po_md` has a runtime
        // dimension due to its undefined size.
        mem_fp.emplace_back(po_md, dnnl_f32, tag::abx, get_cpu_engine());
        mem_dt.emplace_back(po_md, get_test_engine());
        args.push_back(po_idx);
        fill_mem(po_idx, mem_dt.back(), mem_fp.back(), only_positive_values,
                only_integer_values);
    }
    return OK;
}

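// Creates memory descriptors for both sources and the destination from the
// problem description, translates the algorithm and attributes into their
// library counterparts, and creates the binary primitive descriptor.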
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    std::vector<benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t>> src_d(
            prb->n_inputs());

    for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) {
        const dims_t &i_vdims = prb->vdims[i_input];
        src_d[i_input] = dnn_mem_t::init_md(prb->ndims, i_vdims.data(),
                prb->sdt[i_input], prb->stag[i_input]);
    }

    auto dst_d = dnn_mem_t::init_md(
            prb->ndims, prb->dst_dims.data(), prb->ddt, prb->dtag);

    dnnl_alg_kind_t alg = attr_t::post_ops_t::kind2dnnl_kind(prb->alg);

    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, prb->dst_dims.data());
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    DNN_SAFE_STATUS(dnnl_binary_primitive_desc_create(&init_pd_args.pd,
            init_pd_args.engine, alg, src_d[0], src_d[1], dst_d, dnnl_attr));

    return dnnl_success;
}

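// Skips problems the library is not expected to implement: unsupported data
// types, unsupported argument scales, and a GPU-specific configuration.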
void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    std::vector<dnnl_data_type_t> dts = {prb->sdt[0], prb->sdt[1], prb->ddt};
    skip_unimplemented_data_type(dts, prb->dir, res);
    skip_unimplemented_arg_scale(prb->attr, res);

    // N.B.: Adding this for GPU since this cfg is not supported with
    // post-ops.
    if (is_gpu()) {
        bool have_post_ops = !prb->attr.post_ops.is_def();
        bool is_bf16u8 = (dts[0] == dnnl_bf16 && dts[1] == dnnl_bf16
                && dts[2] == dnnl_u8);
        if (is_bf16u8 && have_post_ops) {
            res->state = SKIPPED, res->reason = DATA_TYPE_NOT_SUPPORTED;
            return;
        }
    }
}

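// Skips problems that are invalid by construction rather than merely
// unimplemented, e.g. in-place execution when argument footprints differ.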
void skip_invalid_prb(const prb_t *prb, res_t *res) {
    const bool is_sum = prb->attr.post_ops.find(alg_t::SUM) >= 0;
    bool bcast_src0 = false;
    for (int d = 0; d < prb->ndims; ++d)
        if (prb->vdims[0][d] != prb->vdims[1][d] && prb->vdims[0][d] == 1) {
            bcast_src0 = true;
            break;
        }

    // If src0 is broadcast into src1, src0 has a smaller memory footprint,
    // which means a sum post-op or in-place execution would cause a crash.
    if (bcast_src0 && (prb->inplace || is_sum)) {
        res->state = SKIPPED, res->reason = INVALID_CASE;
        return;
    }

    // See `skip_invalid_inplace` for details.
    if (prb->inplace) {
        if (is_sum) {
            res->state = SKIPPED, res->reason = INVALID_CASE;
            return;
        }

        skip_invalid_inplace(
                res, prb->sdt[0], prb->ddt, prb->stag[0], prb->dtag);
        if (res->state == SKIPPED) return;
    }
}

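// Configures the correctness comparison: the threshold is tied to the
// destination data type, with a driver-specific relaxation for division and
// for comparison algorithms.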
void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    cmp.set_threshold(epsilon_dt(prb->ddt));
    // Since the lambda is called when the stack is unavailable, `prb` must be
    // captured by value to avoid dangling references.
    const auto binary_add_check
            = [prb](const compare::compare_t::driver_check_func_args_t &args) {
                  // fp16 results can slightly mismatch for division due to
                  // differences in backend implementations.
                  return prb->alg == alg_t::DIV
                          ? args.diff < epsilon_dt(args.dt)
                          : false;
              };
    cmp.set_driver_check_function(binary_add_check);

    const std::vector<alg_t> cmp_alg = {
            alg_t::GE, alg_t::GT, alg_t::LE, alg_t::LT, alg_t::EQ, alg_t::NE};
    const bool is_cmp = std::any_of(
            cmp_alg.cbegin(), cmp_alg.cend(), [&](const alg_t alg) {
                return (prb->alg == alg) || prb->attr.post_ops.find(alg) >= 0;
            });

    // Comparison operators produce only 0s and 1s, so an output dominated by
    // zeros is expected rather than suspicious.
    if (is_cmp) cmp.set_zero_trust_percent(99.f);
}

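// The driver entry point: creates the primitive, prepares and fills all
// memory arguments, executes the primitive, optionally validates the result
// against the f32 reference, and measures performance. An illustrative
// command line exercising this path (exact options may vary by version):
//     ./benchdnn --binary --sdt=f32:f32 --ddt=f32 --alg=add 100x2x3:100x2x3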
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    const auto &src0_md = query_md(const_pd, DNNL_ARG_SRC_0);
    const auto &src1_md = query_md(const_pd, DNNL_ARG_SRC_1);
    const auto &dst_md = query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    const auto fp = dnnl_f32;
    const auto tag = tag::abx;

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src0_fp(src0_md, fp, tag, ref_engine);
    dnn_mem_t src0_dt(src0_md, test_engine);
    SAFE(fill_mem(0, src0_dt, src0_fp), WARN);

    dnn_mem_t src1_fp(src1_md, fp, tag, ref_engine);
    dnn_mem_t src1_dt(src1_md, test_engine);
    SAFE(fill_mem(1, src1_dt, src1_fp), WARN);

    dnn_mem_t dst_fp(dst_md, fp, tag, ref_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    // The sum post-op reads the destination, so it must be pre-filled; the
    // same applies on AMD GPUs.
    if (prb->attr.post_ops.find(alg_t::SUM) >= 0 || is_amd_gpu())
        SAFE(fill_mem(2, dst_dt, dst_fp), WARN);

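    // Binary post-op memories must match what the primitive expects, so they
    // are created from descriptors queried off the primitive descriptor.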
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(setup_binary_po(const_pd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);

    args_t args, ref_args;

    args.set(DNNL_ARG_SRC_0, src0_dt);
    args.set(DNNL_ARG_SRC_1, src1_dt);
    args.set(DNNL_ARG_DST, dst_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
    args.set(binary_po_args, binary_po_dt);

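    // Prepare scale arguments for both sources; they are passed at execution
    // time as separate memory objects.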
    dnn_mem_t input_scales_m0;
    float scale0 = prb->attr.scales.get(DNNL_ARG_SRC_0).scale;
    maybe_prepare_runtime_scales(
            input_scales_m0, prb->attr.scales.get(DNNL_ARG_SRC_0), 1, &scale0);
    args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, input_scales_m0);
    dnn_mem_t input_scales_m1;
    float scale1 = prb->attr.scales.get(DNNL_ARG_SRC_1).scale;
    maybe_prepare_runtime_scales(
            input_scales_m1, prb->attr.scales.get(DNNL_ARG_SRC_1), 1, &scale1);
    args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, input_scales_m1);

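    // Execute the primitive and, in correctness mode, compare the result
    // against the f32 reference computed on the CPU engine.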
    SAFE(execute_and_wait(prim, args, res), WARN);

    if (is_bench_mode(CORR)) {
        ref_args.set(DNNL_ARG_SRC_0, src0_fp);
        ref_args.set(DNNL_ARG_SRC_1, src1_fp);
        ref_args.set(DNNL_ARG_DST, dst_fp);
        ref_args.set(binary_po_args, binary_po_fp);

        check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace binary