/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16 | |
17 | #include <float.h> |
18 | #include <math.h> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | |
22 | #include <random> |
23 | |
24 | #include "oneapi/dnnl/dnnl.h" |
25 | |
26 | #include "utils/parallel.hpp" |
27 | |
28 | #include "dnnl_common.hpp" |
29 | #include "dnnl_memory.hpp" |
30 | |
31 | #include "sum/sum.hpp" |
32 | |
33 | namespace sum { |
34 | |
35 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
36 | const prb_t *prb = init_pd_args.prb; |
37 | |
38 | std::vector<benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t>> src_d_wrappers( |
39 | prb->n_inputs()); |
40 | |
41 | for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) |
42 | src_d_wrappers[i_input] = dnn_mem_t::init_md(prb->ndims, |
43 | prb->dims.data(), prb->sdt[i_input], prb->stag[i_input]); |
44 | |
45 | benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> dst_d {}; |
46 | if (prb->dtag != tag::undef) { |
47 | dst_d = dnn_mem_t::init_md( |
48 | prb->ndims, prb->dims.data(), prb->ddt, prb->dtag); |
49 | } |
50 | |
51 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
52 | create_dnnl_attr(prb->attr, attr_args_t())); |
53 | |
54 | std::vector<dnnl_memory_desc_t> src_d( |
55 | src_d_wrappers.begin(), src_d_wrappers.end()); |
56 | init_pd_args.is_iterator_supported = false; |
57 | return dnnl_sum_primitive_desc_create(&init_pd_args.pd, init_pd_args.engine, |
58 | dst_d, prb->n_inputs(), prb->input_scales.data(), src_d.data(), |
59 | dnnl_attr); |
60 | } |
61 | |
62 | int fill_src( |
63 | const prb_t *prb, int input_idx, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) { |
64 | |
65 | const auto nelems = mem_fp.nelems(); |
66 | const auto dt = prb->sdt[input_idx]; |
67 | const int range = 16; |
68 | const int f_min = dt == dnnl_u8 ? 0 : -range / 2; |
69 | |
70 | benchdnn_parallel_nd(nelems, [&](int64_t i) { |
71 | const float gen = ((97 * i) - 17 * input_idx + 101) % range; |
72 | const float value = (dt == dnnl_bf16 || dt == dnnl_f16) |
73 | ? (f_min + gen) / range |
74 | : (f_min + gen) * (1.0f + 4.0f / range); |
75 | mem_fp.set_elem(i, round_to_nearest_representable(dt, value)); |
76 | }); |
77 | |
78 | SAFE(mem_dt.reorder(mem_fp), WARN); |
79 | |
80 | return OK; |
81 | } |
82 | |
83 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
84 | std::vector<dnnl_data_type_t> dts = prb->sdt; |
85 | dts.push_back(prb->ddt); |
86 | skip_unimplemented_data_type(dts, prb->dir, res); |
87 | skip_unimplemented_sum_po(prb->attr, res); |
88 | } |
89 | |
90 | void skip_invalid_prb(const prb_t *prb, res_t *res) { |
91 | // See `skip_invalid_inplace` for details. |
92 | if (prb->inplace) { |
93 | skip_invalid_inplace( |
94 | res, prb->sdt[0], prb->ddt, prb->stag[0], prb->dtag); |
95 | if (res->state == SKIPPED) return; |
96 | } |
97 | } |
98 | |
99 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
100 | const args_t &ref_args) { |
101 | cmp.set_threshold(epsilon_dt(prb->ddt) * prb->n_inputs()); |
102 | } |
103 | |
// Runs one sum test case end-to-end: creates the primitive, allocates and
// fills memories, executes, optionally checks correctness against the f32
// reference, and measures performance.
int doit(const prb_t *prb, res_t *res) {
    // List mode only registers the case; nothing is created or executed.
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();
    const auto &dst_md = query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    // `_dt` memories hold library-format data on the test engine; `_fp`
    // memories hold plain f32 (abx) data on the CPU engine for the reference.
    std::vector<dnn_mem_t> src_fp, src_dt;
    src_fp.reserve(prb->n_inputs());
    src_dt.reserve(prb->n_inputs());

    args_t args, ref_args;
    for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) {
        // Sum sources are addressed off the DNNL_ARG_MULTIPLE_SRC base.
        const auto &src_md
                = query_md(const_pd, DNNL_ARG_MULTIPLE_SRC + i_input);
        src_fp.emplace_back(src_md, dnnl_f32, tag::abx, ref_engine);
        src_dt.emplace_back(src_md, test_engine);
        SAFE(fill_src(prb, i_input, src_dt[i_input], src_fp[i_input]), WARN);
        args.set(DNNL_ARG_MULTIPLE_SRC + i_input, src_dt[i_input]);
        if (is_bench_mode(CORR))
            ref_args.set(DNNL_ARG_MULTIPLE_SRC + i_input, src_fp[i_input]);
    }
    dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t placeholder_dst_dt;

    // In-place execution reuses the first source buffer as the destination;
    // otherwise a dedicated destination memory is allocated.
    if (!prb->inplace) { placeholder_dst_dt = dnn_mem_t(dst_md, test_engine); }
    dnn_mem_t &dst_dt = prb->inplace ? src_dt[0] : placeholder_dst_dt;
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    args.set(DNNL_ARG_DST, dst_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

    SAFE(execute_and_wait(prim, args, res), WARN);

    if (is_bench_mode(CORR)) {
        ref_args.set(DNNL_ARG_DST, dst_fp);

        // Compares `args` destination vs the reference computed from
        // `ref_args`; tolerance comes from setup_cmp.
        check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
    }

    // Performance measurement; also produces the final return status.
    return measure_perf(prb->ctx_exe, res, prim, args);
}
153 | |
154 | } // namespace sum |
155 | |