1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <algorithm> |
18 | |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "utils/parallel.hpp" |
25 | |
26 | #include "dnn_types.hpp" |
27 | #include "dnnl_common.hpp" |
28 | #include "dnnl_memory.hpp" |
29 | |
30 | #include "binary/binary.hpp" |
31 | #include "eltwise/eltwise.hpp" |
32 | |
33 | namespace binary { |
34 | |
35 | //TODO: Consider filling with powers of 2 for division to avoid rounding errors |
36 | int fill_mem(int input_idx, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, |
37 | bool only_positive_values = false, bool only_integer_values = false) { |
38 | const auto nelems = mem_fp.nelems(); |
39 | if (nelems == 0) return OK; |
40 | |
41 | const auto dt = mem_dt.dt(); |
42 | const int range = 16; |
43 | const int f_min = dt == dnnl_u8 ? 0 : -range / 2; |
44 | |
45 | benchdnn_parallel_nd(nelems, [&](int64_t i) { |
46 | const int64_t gen = (12 * i + 5 * input_idx + 16) % (range + 1); |
47 | const float scale = only_integer_values ? 1.f : 1.25f; |
48 | float value = (f_min + gen) * scale; |
49 | if (only_positive_values) value = fabs(value); |
50 | // Remove zeroes in src1 to avoid division by zero |
51 | if (input_idx == 1 && value == 0.0f) value = 1.0f; |
52 | mem_fp.set_elem(i, round_to_nearest_representable(dt, value)); |
53 | }); |
54 | |
55 | SAFE(mem_dt.reorder(mem_fp), WARN); |
56 | |
57 | return OK; |
58 | } |
59 | |
60 | int setup_binary_po(const_dnnl_primitive_desc_t pd, std::vector<int> &args, |
61 | std::vector<dnn_mem_t> &mem_dt, std::vector<dnn_mem_t> &mem_fp, |
62 | bool only_positive_values, bool only_integer_values) { |
63 | // TODO: currently run-time dimensions are not supported in binary post-op. |
64 | // To add a support two ways are possible: 1) add query support to the |
65 | // library and extract expected md from pd; 2) pass a vector of pre-defined |
66 | // (no run-time values) of `po_md`s and create memories from them in case |
67 | // the library will lack of query mechanism. |
68 | auto const_attr_po = query_post_ops(pd); |
69 | auto po_len = dnnl_post_ops_len(const_attr_po); |
70 | for (int idx = 0; idx < po_len; ++idx) { |
71 | auto kind = dnnl_post_ops_get_kind(const_attr_po, idx); |
72 | if (kind != dnnl_binary) continue; |
73 | |
74 | int po_idx = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1; |
75 | const auto &po_md = query_md(pd, po_idx); |
76 | |
77 | // Following call can not be executed if po_md has runtime dimension due |
78 | // to undefined size. |
79 | mem_fp.emplace_back(po_md, dnnl_f32, tag::abx, get_cpu_engine()); |
80 | mem_dt.emplace_back(po_md, get_test_engine()); |
81 | args.push_back(po_idx); |
82 | fill_mem(po_idx, mem_dt.back(), mem_fp.back(), only_positive_values, |
83 | only_integer_values); |
84 | } |
85 | return OK; |
86 | } |
87 | |
88 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
89 | const prb_t *prb = init_pd_args.prb; |
90 | |
91 | std::vector<benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t>> src_d( |
92 | prb->n_inputs()); |
93 | |
94 | for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) { |
95 | const dims_t &i_vdims = prb->vdims[i_input]; |
96 | src_d[i_input] = dnn_mem_t::init_md(prb->ndims, i_vdims.data(), |
97 | prb->sdt[i_input], prb->stag[i_input]); |
98 | } |
99 | |
100 | auto dst_d = dnn_mem_t::init_md( |
101 | prb->ndims, prb->dst_dims.data(), prb->ddt, prb->dtag); |
102 | |
103 | dnnl_alg_kind_t alg = attr_t::post_ops_t::kind2dnnl_kind(prb->alg); |
104 | |
105 | attr_args_t attr_args; |
106 | attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, prb->dst_dims.data()); |
107 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
108 | create_dnnl_attr(prb->attr, attr_args)); |
109 | |
110 | DNN_SAFE_STATUS(dnnl_binary_primitive_desc_create(&init_pd_args.pd, |
111 | init_pd_args.engine, alg, src_d[0], src_d[1], dst_d, dnnl_attr)); |
112 | |
113 | return dnnl_success; |
114 | } |
115 | |
116 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
117 | std::vector<dnnl_data_type_t> dts = {prb->sdt[0], prb->sdt[1], prb->ddt}; |
118 | skip_unimplemented_data_type(dts, prb->dir, res); |
119 | skip_unimplemented_arg_scale(prb->attr, res); |
120 | |
121 | // N.B: Adding this for gpu as cfg is not supported in POST-OPS |
122 | if (is_gpu()) { |
123 | bool have_post_ops = !prb->attr.post_ops.is_def(); |
124 | bool is_bf16u8 = (dts[0] == dnnl_bf16 && dts[1] == dnnl_bf16 |
125 | && dts[2] == dnnl_u8); |
126 | if (is_bf16u8 && have_post_ops) { |
127 | res->state = SKIPPED, res->reason = DATA_TYPE_NOT_SUPPORTED; |
128 | return; |
129 | } |
130 | } |
131 | } |
132 | |
133 | void skip_invalid_prb(const prb_t *prb, res_t *res) { |
134 | const bool is_sum = prb->attr.post_ops.find(alg_t::SUM) >= 0; |
135 | bool bcast_src0 = false; |
136 | for (int d = 0; d < prb->ndims; ++d) |
137 | if (prb->vdims[0][d] != prb->vdims[1][d] && prb->vdims[0][d] == 1) { |
138 | bcast_src0 = true; |
139 | break; |
140 | } |
141 | |
142 | // In case src0 is broadcasted into src1, it means that src0 has smaller |
143 | // memory footprint and doing sum post-op or in-place will cause a crash. |
144 | if (bcast_src0 && (prb->inplace || is_sum)) { |
145 | res->state = SKIPPED, res->reason = INVALID_CASE; |
146 | return; |
147 | } |
148 | |
149 | // See `skip_invalid_inplace` for details. |
150 | if (prb->inplace) { |
151 | if (is_sum) { |
152 | res->state = SKIPPED, res->reason = INVALID_CASE; |
153 | return; |
154 | } |
155 | |
156 | skip_invalid_inplace( |
157 | res, prb->sdt[0], prb->ddt, prb->stag[0], prb->dtag); |
158 | if (res->state == SKIPPED) return; |
159 | } |
160 | } |
161 | |
162 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
163 | const args_t &ref_args) { |
164 | cmp.set_threshold(epsilon_dt(prb->ddt)); |
165 | // Since lambda is called when stack is unavailable, need to capture `prb` |
166 | // by value to avoid using dangling references. |
167 | const auto binary_add_check |
168 | = [prb](const compare::compare_t::driver_check_func_args_t &args) { |
169 | // fp16 result can slightly mismatch for division due to |
170 | // difference in backends implementations. |
171 | return prb->alg == alg_t::DIV |
172 | ? args.diff < epsilon_dt(args.dt) |
173 | : false; |
174 | }; |
175 | cmp.set_driver_check_function(binary_add_check); |
176 | |
177 | const std::vector<alg_t> cmp_alg = { |
178 | alg_t::GE, alg_t::GT, alg_t::LE, alg_t::LT, alg_t::EQ, alg_t::NE}; |
179 | const bool is_cmp = std::any_of( |
180 | cmp_alg.cbegin(), cmp_alg.cend(), [&](const alg_t alg) { |
181 | return (prb->alg == alg) || prb->attr.post_ops.find(alg) >= 0; |
182 | }); |
183 | |
184 | if (is_cmp) cmp.set_zero_trust_percent(99.f); |
185 | } |
186 | |
// Driver entry point: creates the binary primitive for `prb`, fills inputs,
// executes it, checks correctness against the f32 reference (in CORR mode),
// and measures performance. Returns OK or a SAFE-style failure code.
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    // Query actual memory descriptors from the created primitive descriptor;
    // they may differ from the requested ones (e.g. chosen layouts).
    const auto &src0_md = query_md(const_pd, DNNL_ARG_SRC_0);
    const auto &src1_md = query_md(const_pd, DNNL_ARG_SRC_1);
    const auto &dst_md = query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    // Reference memories are plain f32 abx on the CPU engine.
    const auto fp = dnnl_f32;
    const auto tag = tag::abx;

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src0_fp(src0_md, fp, tag, ref_engine);
    dnn_mem_t src0_dt(src0_md, test_engine);
    SAFE(fill_mem(0, src0_dt, src0_fp), WARN);

    dnn_mem_t src1_fp(src1_md, fp, tag, ref_engine);
    dnn_mem_t src1_dt(src1_md, test_engine);
    SAFE(fill_mem(1, src1_dt, src1_fp), WARN);

    dnn_mem_t dst_fp(dst_md, fp, tag, ref_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    // Destination needs pre-filled data when a sum post-op reads it; the AMD
    // GPU path fills it too — presumably to avoid reading uninitialized
    // memory there (TODO confirm).
    if (prb->attr.post_ops.find(alg_t::SUM) >= 0 || is_amd_gpu())
        SAFE(fill_mem(2, dst_dt, dst_fp), WARN);

    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    // Extra memories for binary post-ops attached via attributes.
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(setup_binary_po(const_pd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);

    args_t args, ref_args;

    args.set(DNNL_ARG_SRC_0, src0_dt);
    args.set(DNNL_ARG_SRC_1, src1_dt);
    args.set(DNNL_ARG_DST, dst_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
    args.set(binary_po_args, binary_po_dt);

    // Prepare runtime scale memories for both sources from the attributes.
    dnn_mem_t input_scales_m0;
    float scale0 = prb->attr.scales.get(DNNL_ARG_SRC_0).scale;
    maybe_prepare_runtime_scales(
            input_scales_m0, prb->attr.scales.get(DNNL_ARG_SRC_0), 1, &scale0);
    args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, input_scales_m0);
    dnn_mem_t input_scales_m1;
    float scale1 = prb->attr.scales.get(DNNL_ARG_SRC_1).scale;
    maybe_prepare_runtime_scales(
            input_scales_m1, prb->attr.scales.get(DNNL_ARG_SRC_1), 1, &scale1);
    args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, input_scales_m1);

    SAFE(execute_and_wait(prim, args, res), WARN);

    // Compare against the f32 reference implementation in correctness mode.
    if (is_bench_mode(CORR)) {
        ref_args.set(DNNL_ARG_SRC_0, src0_fp);
        ref_args.set(DNNL_ARG_SRC_1, src1_fp);
        ref_args.set(DNNL_ARG_DST, dst_fp);
        ref_args.set(binary_po_args, binary_po_fp);

        check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}
258 | |
259 | } // namespace binary |
260 | |