1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <float.h>
18#include <math.h>
19#include <stdio.h>
20#include <stdlib.h>
21
22#include <random>
23
24#include "oneapi/dnnl/dnnl.h"
25
26#include "utils/parallel.hpp"
27
28#include "dnnl_common.hpp"
29#include "dnnl_memory.hpp"
30
31#include "sum/sum.hpp"
32
33namespace sum {
34
35dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
36 const prb_t *prb = init_pd_args.prb;
37
38 std::vector<benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t>> src_d_wrappers(
39 prb->n_inputs());
40
41 for (int i_input = 0; i_input < prb->n_inputs(); ++i_input)
42 src_d_wrappers[i_input] = dnn_mem_t::init_md(prb->ndims,
43 prb->dims.data(), prb->sdt[i_input], prb->stag[i_input]);
44
45 benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> dst_d {};
46 if (prb->dtag != tag::undef) {
47 dst_d = dnn_mem_t::init_md(
48 prb->ndims, prb->dims.data(), prb->ddt, prb->dtag);
49 }
50
51 auto dnnl_attr = make_benchdnn_dnnl_wrapper(
52 create_dnnl_attr(prb->attr, attr_args_t()));
53
54 std::vector<dnnl_memory_desc_t> src_d(
55 src_d_wrappers.begin(), src_d_wrappers.end());
56 init_pd_args.is_iterator_supported = false;
57 return dnnl_sum_primitive_desc_create(&init_pd_args.pd, init_pd_args.engine,
58 dst_d, prb->n_inputs(), prb->input_scales.data(), src_d.data(),
59 dnnl_attr);
60}
61
62int fill_src(
63 const prb_t *prb, int input_idx, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
64
65 const auto nelems = mem_fp.nelems();
66 const auto dt = prb->sdt[input_idx];
67 const int range = 16;
68 const int f_min = dt == dnnl_u8 ? 0 : -range / 2;
69
70 benchdnn_parallel_nd(nelems, [&](int64_t i) {
71 const float gen = ((97 * i) - 17 * input_idx + 101) % range;
72 const float value = (dt == dnnl_bf16 || dt == dnnl_f16)
73 ? (f_min + gen) / range
74 : (f_min + gen) * (1.0f + 4.0f / range);
75 mem_fp.set_elem(i, round_to_nearest_representable(dt, value));
76 });
77
78 SAFE(mem_dt.reorder(mem_fp), WARN);
79
80 return OK;
81}
82
83void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
84 std::vector<dnnl_data_type_t> dts = prb->sdt;
85 dts.push_back(prb->ddt);
86 skip_unimplemented_data_type(dts, prb->dir, res);
87 skip_unimplemented_sum_po(prb->attr, res);
88}
89
90void skip_invalid_prb(const prb_t *prb, res_t *res) {
91 // See `skip_invalid_inplace` for details.
92 if (prb->inplace) {
93 skip_invalid_inplace(
94 res, prb->sdt[0], prb->ddt, prb->stag[0], prb->dtag);
95 if (res->state == SKIPPED) return;
96 }
97}
98
99void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
100 const args_t &ref_args) {
101 cmp.set_threshold(epsilon_dt(prb->ddt) * prb->n_inputs());
102}
103
// Main test entry point: create the sum primitive, fill its inputs, execute,
// optionally validate against the f32 reference, and measure performance.
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();
    // Query the descriptors the library actually chose (dst format may have
    // been deduced when dtag was tag::undef in init_pd).
    const auto &dst_md = query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    // src_fp: plain f32 copies for the reference path; src_dt: library-typed
    // buffers for the tested primitive.
    std::vector<dnn_mem_t> src_fp, src_dt;
    src_fp.reserve(prb->n_inputs());
    src_dt.reserve(prb->n_inputs());

    args_t args, ref_args;
    for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) {
        const auto &src_md
                = query_md(const_pd, DNNL_ARG_MULTIPLE_SRC + i_input);
        src_fp.emplace_back(src_md, dnnl_f32, tag::abx, ref_engine);
        src_dt.emplace_back(src_md, test_engine);
        SAFE(fill_src(prb, i_input, src_dt[i_input], src_fp[i_input]), WARN);
        args.set(DNNL_ARG_MULTIPLE_SRC + i_input, src_dt[i_input]);
        if (is_bench_mode(CORR))
            ref_args.set(DNNL_ARG_MULTIPLE_SRC + i_input, src_fp[i_input]);
    }
    dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t placeholder_dst_dt;

    // In-place mode reuses the first source buffer as the destination, so the
    // placeholder is only allocated for the out-of-place case.
    if (!prb->inplace) { placeholder_dst_dt = dnn_mem_t(dst_md, test_engine); }
    dnn_mem_t &dst_dt = prb->inplace ? src_dt[0] : placeholder_dst_dt;
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    args.set(DNNL_ARG_DST, dst_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

    SAFE(execute_and_wait(prim, args, res), WARN);

    if (is_bench_mode(CORR)) {
        ref_args.set(DNNL_ARG_DST, dst_fp);

        // Compares the primitive's DST against the reference computed from
        // ref_args, using the threshold configured in setup_cmp.
        check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
    }

    // Perf measurement is a no-op unless perf mode is enabled.
    return measure_perf(prb->ctx_exe, res, prim, args);
}
153
154} // namespace sum
155