1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
#include <algorithm>
#include <cstring>
#include <utility>

#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
24 | |
25 | #include "utils/parallel.hpp" |
26 | |
27 | #include "dnnl_common.hpp" |
28 | #include "dnnl_memory.hpp" |
29 | |
30 | #include "binary/binary.hpp" |
31 | #include "deconv/deconv.hpp" |
32 | #include "prelu/prelu.hpp" |
33 | |
34 | namespace deconv { |
35 | |
36 | int transpose_data_wei( |
37 | const prb_t *prb, const dnn_mem_t &wei, const dnn_mem_t &wei_tr) { |
38 | benchdnn_parallel_nd(prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kd, |
39 | prb->kh, prb->kw, |
40 | [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh, |
41 | int64_t kw) { |
42 | int64_t ch_idx |
43 | = (g * prb->ic / prb->g + ic) * prb->oc / prb->g + oc; |
44 | int64_t idx = ((ch_idx * prb->kd + kd) * prb->kh + kh) * prb->kw |
45 | + kw; |
46 | ((float *)wei_tr)[idx] |
47 | = ((float *)wei)[wei_off_f(prb, g, oc, ic, kd, kh, kw)]; |
48 | }); |
49 | |
50 | return OK; |
51 | } |
52 | |
53 | double get_non_zero_trust_percent(const prb_t *prb, data_kind_t kind) { |
54 | auto negative_to_zero = [&]() { |
55 | using pk = attr_t::post_ops_t::kind_t; |
56 | const auto &po = prb->attr.post_ops; |
57 | int count = 0; |
58 | |
59 | // Check for all post-ops that convert negative to zero |
60 | std::vector<pk> non_neg_po {pk::ABS}; |
61 | std::vector<pk> non_neg_alpha_0_po { |
62 | pk::CLIP, pk::CLIP_V2, pk::ELU, pk::RELU}; |
63 | for (int i = 0; i < po.len(); ++i) { |
64 | const auto &e = po.entry[i]; |
65 | if (!e.is_eltwise_kind()) continue; |
66 | |
67 | auto k = e.kind; |
68 | auto alpha = e.eltwise.alpha; |
69 | |
70 | count += std::any_of(non_neg_po.cbegin(), non_neg_po.cend(), |
71 | [k](const pk alg) { return alg == k; }); |
72 | count += std::any_of(non_neg_alpha_0_po.cbegin(), |
73 | non_neg_alpha_0_po.cend(), [k, alpha](const pk alg) { |
74 | return alg == k && alpha == 0; |
75 | }); |
76 | } |
77 | // Check for u8 dst |
78 | count += prb->cfg[DST].dt == dnnl_u8; |
79 | // Check for physically padded area in the output |
80 | count += prb->od > prb->id || prb->oh > prb->ih || prb->ow > prb->iw; |
81 | |
82 | return !!count; |
83 | }; |
84 | |
85 | double trust = 0.3; /* why? */ |
86 | switch (kind) { |
87 | case SRC: trust /= prb->sd * prb->sh * prb->sw; break; |
88 | case WEI: |
89 | trust /= 1. * prb->kd * prb->kh * prb->kw |
90 | / MIN3(prb->kd * prb->kh * prb->kw, |
91 | prb->id * prb->ih * prb->iw, |
92 | prb->od * prb->oh * prb->ow); |
93 | break; |
94 | case BIA: |
95 | trust = 0.8 * prb->cfg[DST].f_sparsity; /* why? */ |
96 | break; |
97 | case DST: trust /= (1.f + negative_to_zero()); break; |
98 | default: assert(!"unsupported data kind" ); |
99 | } |
100 | |
101 | return trust; |
102 | } |
103 | |
104 | bool need_src_init(const prb_t *prb) { |
105 | return !(prb->dir == BWD_D); |
106 | } |
107 | |
108 | bool need_wei_init(const prb_t *prb) { |
109 | return !(prb->dir & FLAG_BWD && prb->dir & FLAG_WEI); |
110 | } |
111 | |
112 | bool need_bia_init(const prb_t *prb) { |
113 | return need_wei_init(prb); |
114 | } |
115 | |
116 | bool need_dst_init(const prb_t *prb) { |
117 | return !(prb->dir & FLAG_FWD) |
118 | || (prb->attr.post_ops.find(attr_t::post_ops_t::SUM) >= 0); |
119 | } |
120 | |
121 | int fill_src( |
122 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
123 | const bool check_reorder |
124 | = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt()); |
125 | dnn_mem_t ; |
126 | if (check_reorder) { |
127 | extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine()); |
128 | } |
129 | dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp; |
130 | |
131 | // Use dense filling for small problems. |
132 | int src_nelems_mask = powf(2.f, prb->ndims) - 1; |
133 | src_nelems_mask -= 1; // remove minibatch as independent dimension |
134 | auto src_nelems = prb->desc_nelems(DNNL_ARG_SRC, src_nelems_mask); |
135 | if (prb->has_groups) src_nelems /= prb->g; // groups are also independent |
136 | |
137 | const auto &c = prb->cfg[SRC]; |
138 | const int range = c.f_max - c.f_min + 1; |
139 | const float sparsity |
140 | = (!is_bench_mode(CORR) || src_nelems < 100) ? 1.f : c.f_sparsity; |
141 | |
142 | benchdnn_parallel_nd(prb->mb, prb->ic, prb->id, prb->ih, prb->iw, |
143 | [&](int64_t mb, int64_t ic, int64_t id, int64_t ih, int64_t iw) { |
144 | const int64_t gen |
145 | = 101 * id + 103 * ih + 107 * iw + 109 * mb + 113 * ic; |
146 | const bool non_base = flip_coin(gen, sparsity); |
147 | float value = non_base ? c.f_min + gen * c.f_step % range |
148 | : c.f_base; |
149 | |
150 | maybe_zero_point( |
151 | prb->attr, value, prb->src_zp, ic, DNNL_ARG_SRC, true); |
152 | |
153 | ((float *)mem_00)[src_off_f(prb, mb, 0, ic, id, ih, iw)] |
154 | = round_to_nearest_representable(mem_dt.dt(), value); |
155 | }); |
156 | |
157 | SAFE(mem_dt.reorder(mem_00), WARN); |
158 | if (check_reorder) { |
159 | SAFE(mem_fp.reorder(mem_dt), WARN); |
160 | int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size()); |
161 | if (rc != 0) { |
162 | res->state = FAILED; |
163 | SAFE(FAIL, WARN); |
164 | } |
165 | } |
166 | |
167 | return OK; |
168 | } |
169 | |
170 | int fill_wei( |
171 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
172 | const bool is_def_zp = prb->attr.zero_points.is_def(DNNL_ARG_SRC); |
173 | const bool diff_data_type = mem_dt.dt() != mem_fp.dt(); |
174 | |
175 | dnnl_data_type_t dt_check = dnnl_s8; |
176 | #if DNNL_AARCH64 |
177 | /* Note for x64: |
178 | Both data types of src and weight are s8, oneDNN addds 128 to one of the s8 |
179 | input to make it of type u8 instead, as explained in |
180 | https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html or |
181 | doc/advanced/int8_computations.md |
182 | It is because `VPDPBUSD` instruction uses the combination of s8 and u8 as |
183 | input. |
184 | |
185 | Note for AArch64: |
186 | Because dot product instructions of AArch64 "SDOT" receives s8 input |
187 | for both src and weight, the addition (and its counterpart of subtraction) |
188 | is not required for AArch64. |
189 | */ |
190 | if (res->impl_name.find("jit" , 0) == 0) dt_check = dnnl_u8; |
191 | #endif |
192 | |
193 | const bool wei_x8x8 |
194 | = prb->cfg[WEI].dt == dnnl_s8 && prb->cfg[SRC].dt == dt_check; |
195 | const bool check_reorder |
196 | = (is_bench_mode(CORR)) && diff_data_type && !wei_x8x8 && is_def_zp; |
197 | |
198 | dnn_mem_t ; |
199 | if (check_reorder) { |
200 | extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine()); |
201 | } |
202 | dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp; |
203 | |
204 | const auto &c = prb->cfg[WEI]; |
205 | const int range = c.f_max - c.f_min + 1; |
206 | const float sparsity = !is_bench_mode(CORR) ? 1.f : c.f_sparsity; |
207 | |
208 | benchdnn_parallel_nd(prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kd, |
209 | prb->kh, prb->kw, |
210 | [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh, |
211 | int64_t kw) { |
212 | const int64_t gen = 113 * g + 127 * kd + 131 * kh + 137 * kw |
213 | + 139 * oc + 149 * ic + 151; |
214 | const bool non_base = flip_coin(gen, sparsity); |
215 | const float value = non_base ? c.f_min + gen * c.f_step % range |
216 | : c.f_base; |
217 | ((float *)mem_00)[wei_off_f(prb, g, oc, ic, kd, kh, kw)] |
218 | = value; |
219 | }); |
220 | |
221 | SAFE(mem_dt.reorder(mem_00), WARN); |
222 | if (check_reorder) { |
223 | SAFE(mem_fp.reorder(mem_dt), WARN); |
224 | int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size()); |
225 | if (rc != 0) { |
226 | res->state = FAILED; |
227 | SAFE(FAIL, WARN); |
228 | } |
229 | } |
230 | if ((wei_x8x8 || !is_def_zp) && is_cpu()) { |
231 | // Check that s8 -> s8_comp exists in the library since users may have |
232 | // already quantized data. |
233 | dnn_mem_t mem_fp_s8(mem_fp.md_, dnnl_s8, tag::abx, get_cpu_engine()); |
234 | dnn_mem_t mem_dt_s8(mem_dt.md_, get_test_engine()); |
235 | SAFE(mem_fp_s8.reorder(mem_fp), WARN); |
236 | SAFE(mem_dt_s8.reorder(mem_fp_s8), WARN); |
237 | SAFE(mem_dt.size() == mem_dt_s8.size() ? OK : FAIL, WARN); |
238 | int rc = std::memcmp((void *)mem_dt, (void *)mem_dt_s8, mem_dt.size()); |
239 | SAFE(rc == 0 ? OK : FAIL, WARN); |
240 | } |
241 | return OK; |
242 | } |
243 | |
244 | int fill_bia( |
245 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
246 | const bool check_reorder |
247 | = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt()); |
248 | dnn_mem_t ; |
249 | if (check_reorder) |
250 | extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::x, get_cpu_engine()); |
251 | dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp; |
252 | |
253 | const size_t nelems = mem_00.nelems(); |
254 | if (nelems == 0) return OK; |
255 | |
256 | const auto &c = prb->cfg[BIA]; |
257 | const int range = c.f_max - c.f_min + 1; |
258 | |
259 | for (size_t i = 0; i < nelems; ++i) { |
260 | const int gen = (int)(151 * i); |
261 | const bool non_base = flip_coin(gen, c.f_sparsity); |
262 | const float value |
263 | = non_base ? c.f_min + gen * c.f_step % range : c.f_base; |
264 | |
265 | ((float *)mem_00)[i] = value; |
266 | } |
267 | |
268 | SAFE(mem_dt.reorder(mem_00), WARN); |
269 | if (check_reorder) { |
270 | SAFE(mem_fp.reorder(mem_dt), WARN); |
271 | int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size()); |
272 | if (rc != 0) { |
273 | res->state = FAILED; |
274 | SAFE(FAIL, WARN); |
275 | } |
276 | } |
277 | return OK; |
278 | } |
279 | |
280 | int fill_dst_with_params(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, |
281 | dnnl_data_type_t dt, double sparsity, int min, int max, int base, |
282 | int step, res_t *res) { |
283 | const bool check_reorder |
284 | = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt()); |
285 | dnn_mem_t ; |
286 | if (check_reorder) { |
287 | extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine()); |
288 | } |
289 | |
290 | dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp; |
291 | const int range = max - min + 1; |
292 | |
293 | benchdnn_parallel_nd(prb->mb, prb->oc, prb->od, prb->oh, prb->ow, |
294 | [&](int64_t mb, int64_t oc, int64_t od, int64_t oh, int64_t ow) { |
295 | const int64_t gen |
296 | = 157 * od + 163 * oh + 167 * ow + 173 * mb + 179 * oc; |
297 | const bool non_base = flip_coin(gen, sparsity); |
298 | const float value = non_base ? min + gen * step % range : base; |
299 | |
300 | ((float *)mem_00)[dst_off_f(prb, mb, 0, oc, od, oh, ow)] |
301 | = value; |
302 | }); |
303 | |
304 | SAFE(mem_dt.reorder(mem_00), WARN); |
305 | if (check_reorder) { |
306 | SAFE(mem_fp.reorder(mem_dt), WARN); |
307 | int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size()); |
308 | if (rc != 0) { |
309 | res->state = FAILED; |
310 | SAFE(FAIL, WARN); |
311 | } |
312 | } |
313 | |
314 | return OK; |
315 | } |
316 | |
317 | int fill_dst( |
318 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
319 | auto dst_dt = mem_dt.dt(); |
320 | int sum_ind = prb->attr.post_ops.find(attr_t::post_ops_t::kind_t::SUM); |
321 | auto sum_dt = (sum_ind != -1) ? prb->attr.post_ops.entry[sum_ind].sum.dt |
322 | : dnnl_data_type_undef; |
323 | bool diff_sum_dst_types |
324 | = sum_dt != dnnl_data_type_undef && sum_dt != dst_dt; |
325 | bool sum_dt_is_int8 = sum_dt == dnnl_s8 || sum_dt == dnnl_u8; |
326 | |
327 | const auto &c = prb->cfg[DST]; |
328 | float f_min = c.f_min; |
329 | float f_max = c.f_max; |
330 | if (diff_sum_dst_types && sum_dt_is_int8) { |
331 | f_min = lowest_dt(sum_dt); |
332 | f_max = max_dt(sum_dt); |
333 | } |
334 | |
335 | // Change mem dt to sum dt, so we can save sum data properly. |
336 | if (diff_sum_dst_types) { mem_dt.set_dt(sum_dt); } |
337 | |
338 | fill_dst_with_params(prb, mem_dt, mem_fp, sum_dt, c.f_sparsity, f_min, |
339 | f_max, c.f_base, c.f_step, res); |
340 | |
341 | // Return dst data type back. |
342 | if (diff_sum_dst_types) { mem_dt.set_dt(dst_dt); } |
343 | return OK; |
344 | } |
345 | |
// Create the deconvolution primitive descriptor for the given problem:
// builds memory descriptors, attributes (scales + post-ops), then dispatches
// to the forward / backward-data / backward-weights C-API constructor.
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    // Memory descriptors; weights get one extra leading dim when grouped.
    auto src_d = dnn_mem_t::init_md(
            prb->ndims, prb->src_dims().data(), prb->cfg[SRC].dt, prb->stag);
    auto wei_d = dnn_mem_t::init_md(prb->ndims + prb->has_groups,
            prb->wei_dims().data(), prb->cfg[WEI].dt, prb->wtag);
    // Bias is always 1D; layout left to the library (tag::any).
    auto bia_d = dnn_mem_t::init_md(
            1, prb->bia_dims().data(), prb->cfg[BIA].dt, tag::any);
    auto dst_d = dnn_mem_t::init_md(
            prb->ndims, prb->dst_dims().data(), prb->cfg[DST].dt, prb->dtag);

    dnnl_alg_kind_t alg = dnnl_deconvolution_direct;
    if (prb->alg == WINO) alg = dnnl_deconvolution_winograd;

    // Attributes: per-OC weight scales need mask 2 with groups (bit for g
    // plus bit for oc), 1 otherwise; then post-op memory descriptors.
    attr_args_t attr_args;
    auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS);
    if (wei_scale.policy == policy_t::PER_OC) {
        auto wei_mask = prb->has_groups ? 2 : 1;
        attr_args.prepare_scales(prb->attr, DNNL_ARG_WEIGHTS, prb->wei_scales,
                prb->oc, wei_mask);
    }
    attr_args.prepare_post_ops_mds(
            prb->attr, prb->ndims, prb->dst_dims().data());
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    switch (prb->dir) {
        case FWD_D:
        case FWD_B:
        case FWD_I:
            // Bias only participates in FWD_B; drop it otherwise.
            if (prb->dir != FWD_B) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(dnnl_deconvolution_forward_primitive_desc_create(
                    &init_pd_args.pd, init_pd_args.engine,
                    prb->dir == FWD_I ? dnnl_forward_inference
                                      : dnnl_forward_training,
                    alg, src_d, wei_d, bia_d, dst_d, prb->strides().data(),
                    prb->dilations().data(), prb->padding().data(),
                    prb->padding_r().data(), dnnl_attr));
            break;
        case BWD_D:
            DNN_SAFE_STATUS(
                    dnnl_deconvolution_backward_data_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, alg, src_d,
                            wei_d, dst_d, prb->strides().data(),
                            prb->dilations().data(), prb->padding().data(),
                            prb->padding_r().data(), init_pd_args.hint,
                            dnnl_attr));
            break;
        case BWD_W:
        case BWD_WB:
            // Diff bias only for BWD_WB.
            if (prb->dir == BWD_W) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(
                    dnnl_deconvolution_backward_weights_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, alg, src_d,
                            wei_d, bia_d, dst_d, prb->strides().data(),
                            prb->dilations().data(), prb->padding().data(),
                            prb->padding_r().data(), init_pd_args.hint,
                            dnnl_attr));
            break;
        default: DNN_SAFE_STATUS(dnnl_invalid_arguments);
    }

    // TODO: enabled back when query is added to pd??
    //DNN_SAFE_STATUS(cd.accum_data_type == prb->cfg[ACC].dt
    //        ? dnnl_success
    //        : dnnl_unimplemented);

    return dnnl_success;
}
416 | |
// Create a CPU f32 oneDNN primitive to serve as the reference when testing
// on GPU (much faster than the benchdnn reference). Leaves `prim_ref` empty
// when not applicable or when only the slow "ref:any" impl would be picked.
int init_prim_ref(
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref, const prb_t *prb) {
    // Only relevant for GPU correctness runs with fast-ref enabled.
    if (!(is_bench_mode(CORR) && is_gpu() && fast_ref_gpu)) return OK;

    // Create a new copy of prb to avoid potentially corrupting the test by
    // modifying prb in place.
    // DIRECT algorithm is used to prevent fallback to the slow benchdnn
    // reference implementation.
    auto cpu_attr = prb->attr;
    update_cpu_ref_attrs(cpu_attr);
    prb_t prb_cpu {*prb, prb->dir, conf_f32, tag::abx, tag::abx, tag::abx,
            DIRECT, cpu_attr, prb->ctx_init, prb->ctx_exe, prb->mb};

    init_pd_args_t<prb_t> init_pd_args(
            /* res = */ nullptr, get_cpu_engine(), &prb_cpu, prb->dir,
            /* hint = */ nullptr);
    init_pd(init_pd_args);

    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw;
    fetch_impl(pdw, init_pd_args, /* res = */ nullptr,
            /* is_service_prim = */ true);

    dnnl_primitive_t prim_ref_ {};
    if (pdw) {
        // A plain reference impl would be slower than benchdnn's own; bail.
        if (query_impl_info(pdw) == "ref:any" ) return OK;
        DNN_SAFE(dnnl_primitive_create(&prim_ref_, pdw), WARN);
        BENCHDNN_PRINT(5, "CPU reference oneDNN implementation: %s\n" ,
                query_impl_info(pdw).c_str());
    }
    prim_ref.reset(prim_ref_);
    return OK;
}
449 | |
450 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
451 | skip_unimplemented_data_type( |
452 | {prb->cfg[SRC].dt, prb->cfg[WEI].dt, prb->cfg[DST].dt}, prb->dir, |
453 | res); |
454 | skip_unimplemented_sum_po(prb->attr, res); |
455 | |
456 | // GPU supports only post ops and all but x8s8bf16 cfg |
457 | if (is_gpu()) { |
458 | const bool is_x8s8bf16_cfg |
459 | = prb->cfg[WEI].dt == dnnl_s8 && prb->cfg[DST].dt == dnnl_bf16; |
460 | const bool fwd_ok = !is_x8s8bf16_cfg; |
461 | if (!fwd_ok) { |
462 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
463 | return; |
464 | } |
465 | } |
466 | } |
467 | |
468 | void skip_invalid_prb(const prb_t *prb, res_t *res) {} |
469 | |
470 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
471 | const args_t &ref_args) { |
472 | const bool compare_with_norm = (prb->alg & WINO); |
473 | cmp.set_norm_validation_mode(compare_with_norm); |
474 | |
475 | float trh = prb->cfg[kind].eps; |
476 | if ((prb->alg & WINO) && (prb->dir & FLAG_WEI)) { |
477 | // This is an empirical equation derived by observing growth error with |
478 | // increasing 'k' dimension in gemm of winograd |
479 | const float log_const = log10(0.125 * prb->mb * prb->oh * prb->ow); |
480 | trh = prb->cfg[kind].eps * (MAX2(1, pow(10, 0.4 * log_const))); |
481 | } |
482 | cmp.set_threshold(trh); |
483 | |
484 | const float zpp = (1.f - get_non_zero_trust_percent(prb, kind)) * 100.f; |
485 | cmp.set_zero_trust_percent(zpp); |
486 | } |
487 | |
// Main driver for one deconvolution test case: creates the primitive,
// allocates and fills all memories, executes the requested direction, runs
// correctness checking against the reference, and measures performance.
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    // Query memory descriptors from the created pd; pick the DIFF_* variant
    // for whichever tensors play the gradient role in this direction.
    const auto &src_md = prb->dir == BWD_D
            ? query_md(const_pd, DNNL_ARG_DIFF_SRC)
            : query_md(const_pd, DNNL_ARG_SRC);
    const auto &wei_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_WEIGHTS)
            : query_md(const_pd, DNNL_ARG_WEIGHTS);
    const auto &bia_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_BIAS)
            : query_md(const_pd, DNNL_ARG_BIAS);
    const auto &dst_md = prb->dir & FLAG_BWD
            ? query_md(const_pd, DNNL_ARG_DIFF_DST)
            : query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    // Build the dims of the channel-transposed weights used by the reference
    // (deconv is computed as conv with swapped ic/oc within each group).
    dnnl_dims_t wei_tr_dims {};
    dnnl_dim_t wei_ndims = query_md_ndims(wei_md);
    std::copy(query_md_dims(wei_md), query_md_dims(wei_md) + wei_ndims,
            wei_tr_dims);

    // NOTE(review): groups dim is assumed present unconditionally here —
    // swap happens at indices 1 and 2 regardless of prb->has_groups; confirm
    // the queried wei_md always carries the groups dim.
    const bool with_groups = true;
    std::swap(wei_tr_dims[with_groups + 0], wei_tr_dims[with_groups + 1]);

    const auto fp = dnnl_f32;
    const auto src_tag = tag::abx;
    const auto wei_tag = tag::abx;

    // Use CPU prim as the reference in GPU testing to reduce testing time.
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim_ref;
    SAFE(init_prim_ref(prim_ref, prb), WARN);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    // Library-side memories on the test engine.
    dnn_mem_t src_dt(src_md, test_engine);
    dnn_mem_t wei_dt(wei_md, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    dnn_mem_t bia_dt(bia_md, test_engine);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(
                 const_pd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);
    std::vector<dnn_mem_t> prelu_po_fp, prelu_po_dt;
    std::vector<int> prelu_po_args;
    SAFE(prelu::setup_prelu_po(
                 const_pd, prelu_po_args, prelu_po_fp, prelu_po_dt),
            WARN);

    // Reference-side plain f32 memories on the CPU engine.
    dnn_mem_t src_fp(src_md, fp, src_tag, ref_engine);
    dnn_mem_t wei_fp(wei_md, fp, wei_tag, ref_engine);
    dnn_mem_t dst_fp(dst_md, fp, src_tag, ref_engine);
    dnn_mem_t wei_tr_fp(wei_ndims, wei_tr_dims, fp, wei_tag, ref_engine);
    dnn_mem_t bia_fp(bia_md, fp, tag::x, ref_engine);
    dnn_mem_t scratchpad_fp;

    if (prim_ref)
        scratchpad_fp = dnn_mem_t(
                query_md(query_pd(prim_ref), DNNL_ARG_SCRATCHPAD), ref_engine);

    // Runtime attribute buffers (scales / zero points), filled below for FWD.
    dnn_mem_t src_scales_dt, src_scales_fp;
    dnn_mem_t wei_scales_dt, wei_scales_fp;
    dnn_mem_t dst_scales_dt, dst_scales_fp;
    dnn_mem_t src_zp_fp, src_zp_dt;
    dnn_mem_t dst_zp_fp, dst_zp_dt;

    /* fill memory + reorders <-> */
    if (need_dst_init(prb)) SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);
    if (need_src_init(prb)) SAFE(fill_src(prb, src_dt, src_fp, res), WARN);
    if (need_wei_init(prb)) {
        SAFE(fill_wei(prb, wei_dt, wei_fp, res), WARN);
        // Reference needs the channel-transposed copy as well.
        SAFE(transpose_data_wei(prb, wei_fp, wei_tr_fp), WARN);
    }
    if (need_bia_init(prb)) SAFE(fill_bia(prb, bia_dt, bia_fp, res), WARN);

    args_t args, ref_args;

    if (prb->dir & FLAG_FWD) {
        // Prepare runtime zero points and scales per argument.
        maybe_prepare_runtime_zero_points_v2(src_zp_dt, src_zp_fp, prb->attr,
                DNNL_ARG_SRC, prb->ic, prb->src_zp);
        maybe_prepare_runtime_zero_points_v2(dst_zp_dt, dst_zp_fp, prb->attr,
                DNNL_ARG_DST, prb->oc, prb->dst_zp);
        const int src_mask = attr_t::get_default_mask(
                prb->attr.scales.get(DNNL_ARG_SRC).policy);
        int wei_mask = attr_t::get_default_mask(
                prb->attr.scales.get(DNNL_ARG_WEIGHTS).policy,
                DNNL_ARG_WEIGHTS);
        // With groups, extend the mask to cover the extra groups dim.
        if (prb->has_groups) wei_mask = (1 << wei_mask) + 1;
        const int dst_mask = attr_t::get_default_mask(
                prb->attr.scales.get(DNNL_ARG_DST).policy);
        maybe_prepare_runtime_scales_v2(src_scales_dt, src_scales_fp,
                prb->attr.scales.get(DNNL_ARG_SRC),
                prb->desc_nelems(DNNL_ARG_SRC, src_mask), prb->src_scales);
        maybe_prepare_runtime_scales_v2(wei_scales_dt, wei_scales_fp,
                prb->attr.scales.get(DNNL_ARG_WEIGHTS),
                prb->desc_nelems(DNNL_ARG_WEIGHTS, wei_mask), prb->wei_scales);
        maybe_prepare_runtime_scales_v2(dst_scales_dt, dst_scales_fp,
                prb->attr.scales.get(DNNL_ARG_DST),
                prb->desc_nelems(DNNL_ARG_DST, dst_mask), prb->dst_scales);

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_BIAS, bia_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        args.set(binary_po_args, binary_po_dt);
        args.set(prelu_po_args, prelu_po_dt);
        args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_dt);
        args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_1, wei_tr_fp); // Hack. See ref.
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
            ref_args.set(binary_po_args, binary_po_fp);
            ref_args.set(prelu_po_args, prelu_po_fp);
            ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_fp);
            ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_fp);
            ref_args.set(
                    DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_fp);

            check_correctness(
                    prb, {DST}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir == BWD_D) {
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_SRC, src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_1, wei_tr_fp); // Hack. See ref.
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            check_correctness(
                    prb, {SRC}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) {
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_BIAS, bia_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS_1, wei_tr_fp); // Hack. See ref.
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            check_correctness(
                    prb, {WEI, BIA}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else {
        // Unknown direction: hard failure.
        SAFE(FAIL, CRIT);
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}
675 | |
676 | } // namespace deconv |
677 | |