1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cstring> |
18 | |
19 | #include <float.h> |
20 | #include <math.h> |
21 | #include <stdio.h> |
22 | #include <stdlib.h> |
23 | |
24 | #include "oneapi/dnnl/dnnl.h" |
25 | |
26 | #include "utils/parallel.hpp" |
27 | |
28 | #include "dnnl_common.hpp" |
29 | #include "dnnl_memory.hpp" |
30 | |
31 | #include "binary/binary.hpp" |
32 | #include "ip/ip.hpp" |
33 | |
34 | namespace ip { |
35 | |
// Creates the inner product primitive descriptor for the given test problem.
// Selects the forward / backward-data / backward-weights creation API based
// on `prb->dir` and drops the bias descriptor for flavors that do not use it.
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    // Source and weights honor the user-requested tags; the bias layout is
    // left for the library to choose (tag::any); destination is always 2D.
    auto src_d = dnn_mem_t::init_md(
            prb->ndims, prb->src_dims().data(), prb->cfg[SRC].dt, prb->stag);
    auto wei_d = dnn_mem_t::init_md(
            prb->ndims, prb->wei_dims().data(), prb->cfg[WEI].dt, prb->wtag);
    auto bia_d = dnn_mem_t::init_md(
            1, prb->bia_dims().data(), prb->cfg[BIA].dt, tag::any);
    auto dst_d = dnn_mem_t::init_md(
            2, prb->dst_dims().data(), prb->cfg[DST].dt, prb->dtag);

    // Post-op memory descriptors (e.g. binary post-ops) are built against the
    // 2D destination shape before the attribute object is created.
    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(prb->attr, 2, prb->dst_dims().data());
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    // Note: DNN_SAFE_STATUS returns from this function on failure.
    switch (prb->dir) {
        case FWD_D:
        case FWD_B:
        case FWD_I:
            // Only FWD_B carries a bias; null the descriptor out otherwise.
            if (prb->dir != FWD_B) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(dnnl_inner_product_forward_primitive_desc_create(
                    &init_pd_args.pd, init_pd_args.engine,
                    prb->dir == FWD_I ? dnnl_forward_inference
                                      : dnnl_forward_training,
                    src_d, wei_d, bia_d, dst_d, dnnl_attr));
            break;
        case BWD_D:
            // Backward-data uses a forward hint pd; no bias involved.
            DNN_SAFE_STATUS(
                    dnnl_inner_product_backward_data_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, src_d, wei_d,
                            dst_d, init_pd_args.hint, dnnl_attr));
            break;
        case BWD_W:
        case BWD_WB:
            // BWD_WB additionally computes the bias gradient; plain BWD_W
            // drops the bias descriptor.
            if (prb->dir == BWD_W) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(
                    dnnl_inner_product_backward_weights_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, src_d, wei_d,
                            bia_d, dst_d, init_pd_args.hint, dnnl_attr));
            break;
        default: DNN_SAFE_STATUS(dnnl_invalid_arguments);
    }

    return dnnl_success;
}
83 | |
// Creates an f32 CPU primitive to serve as the reference in GPU correctness
// testing (faster than the plain benchdnn reference path). Leaves `prim_ref`
// empty and returns OK when the fast-ref path is disabled or unavailable.
int init_prim_ref(
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref, const prb_t *prb) {
    if (!(is_bench_mode(CORR) && is_gpu() && fast_ref_gpu)) return OK;

    // Create a new copy of prb to avoid potentially corrupting the test by
    // modifying prb in place. The copy forces an f32 config and plain (abx)
    // layouts so the CPU problem is always creatable.
    auto cpu_attr = prb->attr;
    update_cpu_ref_attrs(cpu_attr);
    prb_t prb_cpu {*prb, prb->mb, prb->dir, conf_f32, tag::abx, tag::abx,
            tag::abx, cpu_attr, prb->ctx_init, prb->ctx_exe};

    init_pd_args_t<prb_t> init_pd_args(
            /* res = */ nullptr, get_cpu_engine(), &prb_cpu, prb->dir,
            /* hint = */ nullptr);
    init_pd(init_pd_args);

    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw;
    fetch_impl(pdw, init_pd_args, /* res = */ nullptr,
            /* is_service_prim = */ true);

    dnnl_primitive_t prim_ref_ {};
    if (pdw) {
        // A plain reference CPU implementation would not be faster than the
        // benchdnn reference itself, so skip it (prim_ref stays empty).
        if (query_impl_info(pdw) == "ref:any" ) return OK;
        DNN_SAFE(dnnl_primitive_create(&prim_ref_, pdw), WARN);
        BENCHDNN_PRINT(5, "CPU reference oneDNN implementation: %s\n" ,
                query_impl_info(pdw).c_str());
    }
    prim_ref.reset(prim_ref_);
    return OK;
}
114 | |
115 | bool need_src_init(const prb_t *prb) { |
116 | return !(prb->dir == BWD_D); |
117 | } |
118 | |
119 | bool need_wei_init(const prb_t *prb) { |
120 | return !(prb->dir & FLAG_BWD && prb->dir & FLAG_WEI); |
121 | } |
122 | |
123 | bool need_bia_init(const prb_t *prb) { |
124 | return need_wei_init(prb); |
125 | } |
126 | |
127 | bool need_dst_init(const prb_t *prb) { |
128 | return !(prb->dir & FLAG_FWD) |
129 | || (prb->attr.post_ops.find(attr_t::post_ops_t::SUM) >= 0); |
130 | } |
131 | |
132 | int fill_src( |
133 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
134 | const auto &c = prb->get_dt_conf(SRC); |
135 | const int range = c.f_max - c.f_min + 1; |
136 | const float sparsity |
137 | = (!is_bench_mode(CORR) || prb->ic < 5) ? 1.f : c.f_sparsity; |
138 | |
139 | benchdnn_parallel_nd(prb->mb, prb->ic, prb->id, prb->ih, prb->iw, |
140 | [&](int64_t mb, int64_t ic, int64_t id, int64_t ih, int64_t iw) { |
141 | const int gen |
142 | = 101 * id + 103 * ih + 107 * iw + 109 * mb + 113 * ic; |
143 | const bool non_base = flip_coin(gen, sparsity); |
144 | const float value |
145 | = non_base ? c.f_min + gen * 1 % range : c.f_base; |
146 | ((float *)mem_fp)[src_off_f(prb, mb, ic, id, ih, iw)] = value; |
147 | }); |
148 | |
149 | SAFE(mem_dt.reorder(mem_fp), WARN); |
150 | |
151 | return OK; |
152 | } |
153 | |
154 | int fill_wei( |
155 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
156 | const bool s8_s8 |
157 | = prb->cfg[WEI].dt == dnnl_s8 && prb->cfg[SRC].dt == dnnl_s8; |
158 | |
159 | const auto &c = prb->get_dt_conf(WEI); |
160 | const int range = c.f_max - c.f_min + 1; |
161 | const float sparsity |
162 | = (!is_bench_mode(CORR) || prb->ic < 5) ? 1.f : c.f_sparsity; |
163 | |
164 | benchdnn_parallel_nd(prb->oc, prb->ic, prb->id, prb->ih, prb->iw, |
165 | [&](int64_t oc, int64_t ic, int64_t kd, int64_t kh, int64_t kw) { |
166 | const int gen = 127 * kd + 131 * kh + 137 * kw + 139 * oc |
167 | + 149 * ic + 7; |
168 | const bool non_base = flip_coin(gen, sparsity); |
169 | const float value |
170 | = non_base ? c.f_min + gen * 1 % range : c.f_base; |
171 | ((float *)mem_fp)[wei_off_f(prb, oc, ic, kd, kh, kw)] = value; |
172 | }); |
173 | |
174 | SAFE(mem_dt.reorder(mem_fp), WARN); |
175 | if (s8_s8 && is_cpu()) { |
176 | // Check that s8 -> s8_comp exists in the library since users may have |
177 | // already quantized data. |
178 | dnn_mem_t mem_fp_s8(mem_fp.md_, dnnl_s8, tag::abx, get_cpu_engine()); |
179 | dnn_mem_t mem_dt_s8(mem_dt.md_, get_test_engine()); |
180 | SAFE(mem_fp_s8.reorder(mem_fp), WARN); |
181 | SAFE(mem_dt_s8.reorder(mem_fp_s8), WARN); |
182 | SAFE(mem_dt.size() == mem_dt_s8.size() ? OK : FAIL, WARN); |
183 | int rc = std::memcmp((void *)mem_dt, (void *)mem_dt_s8, mem_dt.size()); |
184 | SAFE(rc == 0 ? OK : FAIL, WARN); |
185 | } |
186 | |
187 | return OK; |
188 | } |
189 | |
190 | int fill_bia( |
191 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
192 | const size_t nelems = mem_fp.nelems(); |
193 | if (nelems == 0) return OK; |
194 | |
195 | const auto &c = prb->get_dt_conf(BIA); |
196 | const int range = c.f_max - c.f_min + 1; |
197 | |
198 | for (size_t i = 0; i < nelems; ++i) { |
199 | const int gen = (int)(151 * i + 11); |
200 | const bool non_base = flip_coin(gen, c.f_sparsity); |
201 | const float value = non_base ? c.f_min + gen * 1 % range : c.f_base; |
202 | ((float *)mem_fp)[i] = value; |
203 | } |
204 | |
205 | SAFE(mem_dt.reorder(mem_fp), WARN); |
206 | return OK; |
207 | } |
208 | |
209 | int fill_dst( |
210 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
211 | const auto &c = prb->get_dt_conf(DST); |
212 | const int range = c.f_max - c.f_min + 1; |
213 | |
214 | benchdnn_parallel_nd(prb->mb, prb->oc, [&](int64_t mb, int64_t oc) { |
215 | const int gen = 173 * mb + 179 * oc; |
216 | const bool non_base = flip_coin(gen, c.f_sparsity); |
217 | const float value = non_base ? c.f_min + gen * 1 % range : c.f_base; |
218 | |
219 | ((float *)mem_fp)[dst_off_f(prb, mb, oc)] = value; |
220 | }); |
221 | |
222 | SAFE(mem_dt.reorder(mem_fp), WARN); |
223 | |
224 | return OK; |
225 | } |
226 | |
227 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
228 | skip_unimplemented_data_type( |
229 | {prb->cfg[SRC].dt, prb->cfg[WEI].dt, prb->cfg[DST].dt}, prb->dir, |
230 | res); |
231 | |
232 | if (is_cpu()) { |
233 | |
234 | auto is_dt_f16_or_f32 = [&](dnnl_data_type_t dt) { |
235 | return dt == dnnl_f16 || dt == dnnl_f32; |
236 | }; |
237 | |
238 | if (!IMPLICATION(prb->cfg[SRC].dt == dnnl_f16 |
239 | || prb->cfg[WEI].dt == dnnl_f16 |
240 | || prb->cfg[DST].dt == dnnl_f16, |
241 | is_dt_f16_or_f32(prb->cfg[SRC].dt) |
242 | && is_dt_f16_or_f32(prb->cfg[WEI].dt) |
243 | && is_dt_f16_or_f32(prb->cfg[DST].dt))) { |
244 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
245 | } |
246 | } |
247 | |
248 | skip_unimplemented_sum_po(prb->attr, res, prb->get_dt_conf(DST).dt); |
249 | } |
250 | |
// Inner product has no invalid problem configurations to reject beyond what
// skip_unimplemented_prb() already covers; intentionally empty.
void skip_invalid_prb(const prb_t *prb, res_t *res) {}
252 | |
253 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
254 | const args_t &ref_args) { |
255 | cmp.set_threshold(prb->cfg[DST].eps); |
256 | |
257 | // TODO: why so bad filling? |
258 | const float zero_trust_percent = kind == WEI || kind == BIA ? 90.f : 80.f; |
259 | cmp.set_zero_trust_percent(zero_trust_percent); |
260 | } |
261 | |
// Driver entry point for a single test case: creates the primitive, queries
// the implementation-chosen memory descriptors, allocates and fills all
// memories, executes the primitive with direction-specific arguments,
// optionally checks correctness against the reference, and finally measures
// performance. Statement order matters: fills must happen after descriptor
// queries and before execution.
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    // Query the actual (possibly blocked) descriptors chosen by the
    // implementation; diff_* variants are used on backward passes.
    const auto &src_md = prb->dir == BWD_D
            ? query_md(const_pd, DNNL_ARG_DIFF_SRC)
            : query_md(const_pd, DNNL_ARG_SRC);
    const auto &wei_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_WEIGHTS)
            : query_md(const_pd, DNNL_ARG_WEIGHTS);
    const auto &bia_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_BIAS)
            : query_md(const_pd, DNNL_ARG_BIAS);
    const auto &dst_md = prb->dir & FLAG_BWD
            ? query_md(const_pd, DNNL_ARG_DIFF_DST)
            : query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    // Reference memories are always f32 in plain layouts.
    const auto fp = dnnl_f32;
    const auto src_tag = tag::abx;
    const auto wei_tag = tag::abx;

    // Use CPU prim as the reference in GPU testing to reduce testing time.
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim_ref;
    SAFE(init_prim_ref(prim_ref, prb), WARN);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_dt(src_md, test_engine);
    dnn_mem_t wei_dt(wei_md, test_engine);
    dnn_mem_t bia_dt(bia_md, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    // Runtime scale buffers; masks derive from the attribute policies.
    dnn_mem_t src_scales_dt, src_scales_fp;
    dnn_mem_t wei_scales_dt, wei_scales_fp;
    dnn_mem_t dst_scales_dt, dst_scales_fp;
    const int src_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_SRC).policy);
    const int wei_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_WEIGHTS).policy, DNNL_ARG_WEIGHTS);
    const int dst_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_DST).policy);
    maybe_prepare_runtime_scales_v2(src_scales_dt, src_scales_fp,
            prb->attr.scales.get(DNNL_ARG_SRC),
            prb->desc_nelems(DNNL_ARG_SRC, src_mask), prb->src_scales);
    maybe_prepare_runtime_scales_v2(wei_scales_dt, wei_scales_fp,
            prb->attr.scales.get(DNNL_ARG_WEIGHTS),
            prb->desc_nelems(DNNL_ARG_WEIGHTS, wei_mask), prb->wei_scales);
    maybe_prepare_runtime_scales_v2(dst_scales_dt, dst_scales_fp,
            prb->attr.scales.get(DNNL_ARG_DST),
            prb->desc_nelems(DNNL_ARG_DST, dst_mask), prb->dst_scales);

    // Memories for binary post-op sources, paired with their arg indices.
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(
                 const_pd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);

    dnn_mem_t src_fp(src_md, fp, src_tag, ref_engine);
    dnn_mem_t wei_fp(wei_md, fp, wei_tag, ref_engine);
    dnn_mem_t bia_fp(bia_md, fp, tag::x, ref_engine);
    dnn_mem_t dst_fp(dst_md, fp, tag::abx, ref_engine);

    // Fill only the tensors that act as inputs for this direction.
    if (need_src_init(prb)) SAFE(fill_src(prb, src_dt, src_fp, res), WARN);
    if (need_wei_init(prb)) SAFE(fill_wei(prb, wei_dt, wei_fp, res), WARN);
    if (need_bia_init(prb)) SAFE(fill_bia(prb, bia_dt, bia_fp, res), WARN);
    if (need_dst_init(prb)) SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);

    // The CPU reference primitive may need its own scratchpad.
    dnn_mem_t scratchpad_fp;
    if (prim_ref)
        scratchpad_fp = dnn_mem_t(
                query_md(query_pd(prim_ref), DNNL_ARG_SCRATCHPAD), ref_engine);

    args_t args, ref_args;

    if (prb->dir & FLAG_FWD) {
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_BIAS, bia_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);
        args.set(binary_po_args, binary_po_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(binary_po_args, binary_po_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
            ref_args.set(
                    DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);

            // Forward: only the destination is compared.
            check_correctness(
                    prb, {DST}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir == BWD_D) {
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_SRC, src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            // Backward-data: only diff_src is compared.
            check_correctness(
                    prb, {SRC}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) {
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_BIAS, bia_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            // Backward-weights: diff_weights and diff_bias are compared.
            check_correctness(
                    prb, {WEI, BIA}, args, ref_args, setup_cmp, res, prim_ref);
        }
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}
411 | |
412 | } // namespace ip |
413 | |