/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cstring>

#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "oneapi/dnnl/dnnl.h"

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "binary/binary.hpp"
#include "ip/ip.hpp"

namespace ip {

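// Creates memory descriptors for all tensors and the inner product primitive
// descriptor for the propagation kind requested by the problem descriptor.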
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    auto src_d = dnn_mem_t::init_md(
            prb->ndims, prb->src_dims().data(), prb->cfg[SRC].dt, prb->stag);
    auto wei_d = dnn_mem_t::init_md(
            prb->ndims, prb->wei_dims().data(), prb->cfg[WEI].dt, prb->wtag);
    auto bia_d = dnn_mem_t::init_md(
            1, prb->bia_dims().data(), prb->cfg[BIA].dt, tag::any);
    auto dst_d = dnn_mem_t::init_md(
            2, prb->dst_dims().data(), prb->cfg[DST].dt, prb->dtag);

    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(prb->attr, 2, prb->dst_dims().data());
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    switch (prb->dir) {
        case FWD_D:
        case FWD_B:
        case FWD_I:
            if (prb->dir != FWD_B) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(dnnl_inner_product_forward_primitive_desc_create(
                    &init_pd_args.pd, init_pd_args.engine,
                    prb->dir == FWD_I ? dnnl_forward_inference
                                      : dnnl_forward_training,
                    src_d, wei_d, bia_d, dst_d, dnnl_attr));
            break;
        case BWD_D:
            DNN_SAFE_STATUS(
                    dnnl_inner_product_backward_data_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, src_d, wei_d,
                            dst_d, init_pd_args.hint, dnnl_attr));
            break;
        case BWD_W:
        case BWD_WB:
            if (prb->dir == BWD_W) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(
                    dnnl_inner_product_backward_weights_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, src_d, wei_d,
                            bia_d, dst_d, init_pd_args.hint, dnnl_attr));
            break;
        default: DNN_SAFE_STATUS(dnnl_invalid_arguments);
    }

    return dnnl_success;
}

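// When testing on GPU in correctness mode, creates an f32 CPU inner product
// primitive to serve as a faster reference than the plain reference path.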
int init_prim_ref(
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref, const prb_t *prb) {
    if (!(is_bench_mode(CORR) && is_gpu() && fast_ref_gpu)) return OK;

    // Create a new copy of prb to avoid potentially corrupting the test by
    // modifying prb in place.
    auto cpu_attr = prb->attr;
    update_cpu_ref_attrs(cpu_attr);
    prb_t prb_cpu {*prb, prb->mb, prb->dir, conf_f32, tag::abx, tag::abx,
            tag::abx, cpu_attr, prb->ctx_init, prb->ctx_exe};

    init_pd_args_t<prb_t> init_pd_args(
            /* res = */ nullptr, get_cpu_engine(), &prb_cpu, prb->dir,
            /* hint = */ nullptr);
    init_pd(init_pd_args);

    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw;
    fetch_impl(pdw, init_pd_args, /* res = */ nullptr,
            /* is_service_prim = */ true);

    dnnl_primitive_t prim_ref_ {};
    if (pdw) {
        if (query_impl_info(pdw) == "ref:any") return OK;
        DNN_SAFE(dnnl_primitive_create(&prim_ref_, pdw), WARN);
        BENCHDNN_PRINT(5, "CPU reference oneDNN implementation: %s\n",
                query_impl_info(pdw).c_str());
    }
    prim_ref.reset(prim_ref_);
    return OK;
}

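// The need_*_init helpers tell which tensors require filling for a given
// propagation direction: tensors the primitive computes are outputs and need
// no input data. dst is the exception; it is also filled on forward when a
// sum post-op accumulates into it.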
bool need_src_init(const prb_t *prb) {
    return !(prb->dir == BWD_D);
}

bool need_wei_init(const prb_t *prb) {
    return !(prb->dir & FLAG_BWD && prb->dir & FLAG_WEI);
}

bool need_bia_init(const prb_t *prb) {
    return need_wei_init(prb);
}

bool need_dst_init(const prb_t *prb) {
    return !(prb->dir & FLAG_FWD)
            || (prb->attr.post_ops.find(attr_t::post_ops_t::SUM) >= 0);
}

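// Fills src with values derived deterministically from element coordinates;
// outside correctness mode, or when ic < 5, every element gets a non-base
// value.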
int fill_src(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    const auto &c = prb->get_dt_conf(SRC);
    const int range = c.f_max - c.f_min + 1;
    const float sparsity
            = (!is_bench_mode(CORR) || prb->ic < 5) ? 1.f : c.f_sparsity;

    benchdnn_parallel_nd(prb->mb, prb->ic, prb->id, prb->ih, prb->iw,
            [&](int64_t mb, int64_t ic, int64_t id, int64_t ih, int64_t iw) {
                const int gen
                        = 101 * id + 103 * ih + 107 * iw + 109 * mb + 113 * ic;
                const bool non_base = flip_coin(gen, sparsity);
                const float value
                        = non_base ? c.f_min + gen * 1 % range : c.f_base;
                ((float *)mem_fp)[src_off_f(prb, mb, ic, id, ih, iw)] = value;
            });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

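// Fills weights the same way as src. For s8 src with s8 weights on CPU it
// additionally verifies that reordering already-quantized s8 data into the
// library format produces the same bytes as reordering from f32.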
int fill_wei(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    const bool s8_s8
            = prb->cfg[WEI].dt == dnnl_s8 && prb->cfg[SRC].dt == dnnl_s8;

    const auto &c = prb->get_dt_conf(WEI);
    const int range = c.f_max - c.f_min + 1;
    const float sparsity
            = (!is_bench_mode(CORR) || prb->ic < 5) ? 1.f : c.f_sparsity;

    benchdnn_parallel_nd(prb->oc, prb->ic, prb->id, prb->ih, prb->iw,
            [&](int64_t oc, int64_t ic, int64_t kd, int64_t kh, int64_t kw) {
                const int gen = 127 * kd + 131 * kh + 137 * kw + 139 * oc
                        + 149 * ic + 7;
                const bool non_base = flip_coin(gen, sparsity);
                const float value
                        = non_base ? c.f_min + gen * 1 % range : c.f_base;
                ((float *)mem_fp)[wei_off_f(prb, oc, ic, kd, kh, kw)] = value;
            });

    SAFE(mem_dt.reorder(mem_fp), WARN);
    if (s8_s8 && is_cpu()) {
        // Check that s8 -> s8_comp exists in the library since users may have
        // already quantized data.
        dnn_mem_t mem_fp_s8(mem_fp.md_, dnnl_s8, tag::abx, get_cpu_engine());
        dnn_mem_t mem_dt_s8(mem_dt.md_, get_test_engine());
        SAFE(mem_fp_s8.reorder(mem_fp), WARN);
        SAFE(mem_dt_s8.reorder(mem_fp_s8), WARN);
        SAFE(mem_dt.size() == mem_dt_s8.size() ? OK : FAIL, WARN);
        int rc = std::memcmp((void *)mem_dt, (void *)mem_dt_s8, mem_dt.size());
        SAFE(rc == 0 ? OK : FAIL, WARN);
    }

    return OK;
}

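// Fills the 1D bias tensor sequentially with the same base/non-base scheme
// used for the other tensors.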
int fill_bia(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    const size_t nelems = mem_fp.nelems();
    if (nelems == 0) return OK;

    const auto &c = prb->get_dt_conf(BIA);
    const int range = c.f_max - c.f_min + 1;

    for (size_t i = 0; i < nelems; ++i) {
        const int gen = (int)(151 * i + 11);
        const bool non_base = flip_coin(gen, c.f_sparsity);
        const float value = non_base ? c.f_min + gen * 1 % range : c.f_base;
        ((float *)mem_fp)[i] = value;
    }

    SAFE(mem_dt.reorder(mem_fp), WARN);
    return OK;
}

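// Fills dst; needed on backward passes, where it holds diff_dst, and on
// forward when a sum post-op reads from it.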
int fill_dst(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    const auto &c = prb->get_dt_conf(DST);
    const int range = c.f_max - c.f_min + 1;

    benchdnn_parallel_nd(prb->mb, prb->oc, [&](int64_t mb, int64_t oc) {
        const int gen = 173 * mb + 179 * oc;
        const bool non_base = flip_coin(gen, c.f_sparsity);
        const float value = non_base ? c.f_min + gen * 1 % range : c.f_base;

        ((float *)mem_fp)[dst_off_f(prb, mb, oc)] = value;
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

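// Skips problems the library is known not to implement, e.g. on CPU an f16
// tensor may only be combined with f16 or f32 tensors.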
void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type(
            {prb->cfg[SRC].dt, prb->cfg[WEI].dt, prb->cfg[DST].dt}, prb->dir,
            res);

    if (is_cpu()) {
        auto is_dt_f16_or_f32 = [&](dnnl_data_type_t dt) {
            return dt == dnnl_f16 || dt == dnnl_f32;
        };

        if (!IMPLICATION(prb->cfg[SRC].dt == dnnl_f16
                                || prb->cfg[WEI].dt == dnnl_f16
                                || prb->cfg[DST].dt == dnnl_f16,
                    is_dt_f16_or_f32(prb->cfg[SRC].dt)
                            && is_dt_f16_or_f32(prb->cfg[WEI].dt)
                            && is_dt_f16_or_f32(prb->cfg[DST].dt))) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        }
    }

    skip_unimplemented_sum_po(prb->attr, res, prb->get_dt_conf(DST).dt);
}

void skip_invalid_prb(const prb_t *prb, res_t *res) {}

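// Configures the comparator: the threshold comes from the dst config, and the
// zero-trust percentage is relaxed because the fillings above leave many
// elements at the base value.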
void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    cmp.set_threshold(prb->cfg[DST].eps);

    // TODO: why does the filling produce so many zero values?
    const float zero_trust_percent = kind == WEI || kind == BIA ? 90.f : 80.f;
    cmp.set_zero_trust_percent(zero_trust_percent);
}

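// Main driver entry point: creates the primitive, allocates and fills all
// memories, executes the requested direction, checks correctness against the
// reference, and measures performance.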
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    const auto &src_md = prb->dir == BWD_D
            ? query_md(const_pd, DNNL_ARG_DIFF_SRC)
            : query_md(const_pd, DNNL_ARG_SRC);
    const auto &wei_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_WEIGHTS)
            : query_md(const_pd, DNNL_ARG_WEIGHTS);
    const auto &bia_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_BIAS)
            : query_md(const_pd, DNNL_ARG_BIAS);
    const auto &dst_md = prb->dir & FLAG_BWD
            ? query_md(const_pd, DNNL_ARG_DIFF_DST)
            : query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    const auto fp = dnnl_f32;
    const auto src_tag = tag::abx;
    const auto wei_tag = tag::abx;

    // Use CPU prim as the reference in GPU testing to reduce testing time.
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim_ref;
    SAFE(init_prim_ref(prim_ref, prb), WARN);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_dt(src_md, test_engine);
    dnn_mem_t wei_dt(wei_md, test_engine);
    dnn_mem_t bia_dt(bia_md, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
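    // Prepare runtime scale buffers; each mask is derived from the scale
    // policy attached to the corresponding argument.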
    dnn_mem_t src_scales_dt, src_scales_fp;
    dnn_mem_t wei_scales_dt, wei_scales_fp;
    dnn_mem_t dst_scales_dt, dst_scales_fp;
    const int src_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_SRC).policy);
    const int wei_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_WEIGHTS).policy, DNNL_ARG_WEIGHTS);
    const int dst_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_DST).policy);
    maybe_prepare_runtime_scales_v2(src_scales_dt, src_scales_fp,
            prb->attr.scales.get(DNNL_ARG_SRC),
            prb->desc_nelems(DNNL_ARG_SRC, src_mask), prb->src_scales);
    maybe_prepare_runtime_scales_v2(wei_scales_dt, wei_scales_fp,
            prb->attr.scales.get(DNNL_ARG_WEIGHTS),
            prb->desc_nelems(DNNL_ARG_WEIGHTS, wei_mask), prb->wei_scales);
    maybe_prepare_runtime_scales_v2(dst_scales_dt, dst_scales_fp,
            prb->attr.scales.get(DNNL_ARG_DST),
            prb->desc_nelems(DNNL_ARG_DST, dst_mask), prb->dst_scales);

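    // Set up memories for binary post-op arguments queried from the primitive
    // descriptor, then create the f32 reference memories and fill all inputs.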
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(
                 const_pd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);

    dnn_mem_t src_fp(src_md, fp, src_tag, ref_engine);
    dnn_mem_t wei_fp(wei_md, fp, wei_tag, ref_engine);
    dnn_mem_t bia_fp(bia_md, fp, tag::x, ref_engine);
    dnn_mem_t dst_fp(dst_md, fp, tag::abx, ref_engine);

    if (need_src_init(prb)) SAFE(fill_src(prb, src_dt, src_fp, res), WARN);
    if (need_wei_init(prb)) SAFE(fill_wei(prb, wei_dt, wei_fp, res), WARN);
    if (need_bia_init(prb)) SAFE(fill_bia(prb, bia_dt, bia_fp, res), WARN);
    if (need_dst_init(prb)) SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);

    dnn_mem_t scratchpad_fp;
    if (prim_ref)
        scratchpad_fp = dnn_mem_t(
                query_md(query_pd(prim_ref), DNNL_ARG_SCRATCHPAD), ref_engine);

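    // Bind arguments for the requested direction, execute the primitive, and
    // validate the outputs against the reference in correctness mode.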
    args_t args, ref_args;

    if (prb->dir & FLAG_FWD) {
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_BIAS, bia_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);
        args.set(binary_po_args, binary_po_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(binary_po_args, binary_po_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
            ref_args.set(
                    DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);

            check_correctness(
                    prb, {DST}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir == BWD_D) {
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_SRC, src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            check_correctness(
                    prb, {SRC}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) {
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_BIAS, bia_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            check_correctness(
                    prb, {WEI, BIA}, args, ref_args, setup_cmp, res, prim_ref);
        }
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace ip