/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <math.h>

#include <random>
#include <sstream>

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "binary/binary.hpp"
#include "reduction/reduction.hpp"

namespace reduction {

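// Creates a reduction primitive descriptor for the given problem: source and
// destination memory descriptors are built from the problem dimensions,
// memory tags, and data types, and the post-op attributes are attached.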
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    auto src_desc = dnn_mem_t::init_md(
            prb->ndims, prb->vdims[0].data(), prb->sdt, prb->stag);
    auto dst_desc = dnn_mem_t::init_md(
            prb->ndims, prb->vdims[1].data(), prb->ddt, prb->dtag);

    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, prb->vdims[1].data());
    const auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    DNN_SAFE_STATUS(dnnl_reduction_primitive_desc_create(&init_pd_args.pd,
            init_pd_args.engine, alg2alg_kind(prb->alg), src_desc, dst_desc,
            prb->p, prb->eps, dnnl_attr));

    return dnnl_success;
}

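// Returns true for the Lp-norm family of reduction algorithms.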
bool is_norm_alg(const alg_t alg) {
    return alg == alg_t::norm_lp_max || alg == alg_t::norm_lp_sum
            || alg == alg_t::norm_lp_power_p_max
            || alg == alg_t::norm_lp_power_p_sum;
}

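// Fills a memory object with mostly neutral values (1 for mul, 0 for the
// other algorithms), replacing them with random values with probability
// `non_neutral_prob`. Limiting the number of non-neutral elements keeps the
// reduced result within the data type range and bounds the accumulation
// error. `use_reduced_range` shrinks the sampled value range, and
// `only_positive_values` suppresses sign flipping.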
int fill_mem(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
        float non_neutral_prob, bool use_reduced_range,
        bool only_positive_values) {
    const auto sdt = mem_dt.dt();
    const auto nelems = mem_fp.nelems();
    const float neutral_value = prb->alg == alg_t::mul ? 1.0f : 0.0f;
    const float mean_shift = prb->alg == alg_t::mean ? 1.0f : 0.0f;
    const bool is_signed = sdt != dnnl_u8;
    const bool is_int = is_integral_dt(sdt);

    int value_range = use_reduced_range ? 16 : 1000;
    if (is_int) value_range = use_reduced_range ? 3 : max_dt(dnnl_s8);

    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);

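    // Fill in parallel over a fixed number of chunks; each chunk seeds its
    // own RNG from the chunk start index so the filling is deterministic.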
    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        const int64_t idx_start = idx_chunk * chunk_size;
        const int64_t idx_end = MIN2(idx_start + chunk_size, nelems);

        std::minstd_rand msr(idx_start + 1);
        msr.discard(1);
        std::uniform_int_distribution<> igen(1, value_range);

        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            float value = neutral_value;
            if (flip_coin(idx, non_neutral_prob)) {
                const int gen = igen(msr);
                value = is_int ? gen : gen / 8.f;
                if (!only_positive_values && is_signed && flip_coin(gen, 0.5f))
                    value = -value;
            }
            value += mean_shift;
            mem_fp.set_elem(idx, round_to_nearest_representable(sdt, value));
        }
    });
    SAFE(mem_dt.reorder(mem_fp), WARN);
    return OK;
}

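// Fills the source tensor. The probability of non-neutral values is scaled
// so that, on average, only a safe number of elements contribute to each
// reduced output point.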
int fill_src(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
    const auto nelems = mem_fp.nelems();
    const auto ddt = prb->ddt;
    if (!nelems) return OK;

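    // Count how many input elements collapse into a single output element.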
    int nelems_to_reduce = 1;
    for (int dim = 0; dim < prb->ndims; dim++) {
        if (prb->vdims[0][dim] != prb->vdims[1][dim]) {
            nelems_to_reduce *= prb->vdims[0][dim];
        }
    }
    // There is no accumulation error in case of the min or max algorithms.
    const bool is_min_or_max = prb->alg == alg_t::min || prb->alg == alg_t::max;
    // Cap on the number of non-neutral elements so that the reduced value
    // stays within the data type limits.
    int safe_to_reduce_elems = nelems_to_reduce;
    if (!is_min_or_max) { // Other algorithms accumulate, so cap the count.
        safe_to_reduce_elems = prb->alg == alg_t::mul ? 10 : 1000;
        // Integer values reach the data type limits quickly,
        // so shrink the cap further.
        if (is_integral_dt(ddt))
            safe_to_reduce_elems = prb->alg == alg_t::mul ? 3 : 10;
    }
    const float non_neutral_prob
            = 1.f * safe_to_reduce_elems / nelems_to_reduce;

    return fill_mem(
            prb, mem_dt, mem_fp, non_neutral_prob, !is_min_or_max, false);
}

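// Fills the destination tensor; needed when a sum post-op reads the original
// `dst` values. Norm algorithms produce non-negative outputs, so only
// positive values are generated for them.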
int fill_dst(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
    const bool only_positive_values = is_norm_alg(prb->alg);
    return fill_mem(prb, mem_dt, mem_fp, 1.0f, false, only_positive_values);
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type({prb->sdt, prb->ddt}, prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res);
}

void skip_invalid_prb(const prb_t *prb, res_t *res) {
    // Norm algorithms are not defined for integer data types and require
    // the `p` parameter to be at least one.
    const bool is_invalid = is_norm_alg(prb->alg)
            && (is_integral_dt(prb->sdt) || prb->p < 1.f);

    if (is_invalid) {
        res->state = SKIPPED, res->reason = INVALID_CASE;
        return;
    }
}

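// Configures the correctness comparator. The threshold accounts for the
// accumulation error of the reduction and is relaxed for less accurate
// backends.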
void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    // `5` is a temporary magic constant for GPU to pass norm algorithms.
    // TODO: consider changing the filling to power-of-two values for better
    // answer precision.
    cmp.set_threshold(5 * epsilon_dt(prb->ddt));
    if (is_amd_gpu()) {
        // The MIOpen implementation is less accurate for the f16 data type,
        // so adjust the threshold.
        if (prb->sdt == dnnl_f16 || prb->ddt == dnnl_f16)
            cmp.set_threshold(1.5e-4 * 4);
    }
}

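// Main test driver: creates the primitive, fills the inputs, executes, and,
// in correctness mode, validates the result against the f32 reference path.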
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

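    // Reference data is kept in plain f32 `abx` memory on the CPU engine.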
    const auto fp_dt = dnnl_f32;
    const auto abx_tag = tag::abx;

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    const auto &src_md = query_md(const_pd, DNNL_ARG_SRC);
    dnn_mem_t src_fp(src_md, fp_dt, abx_tag, ref_engine);
    dnn_mem_t src_dt(src_md, test_engine);
    SAFE(fill_src(prb, src_dt, src_fp), WARN);

    const auto &dst_md = query_md(const_pd, DNNL_ARG_DST);
    dnn_mem_t dst_fp(dst_md, fp_dt, abx_tag, ref_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
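    // A sum post-op reads the original `dst` values, so they must be filled.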
    if (prb->attr.post_ops.find(attr_t::post_ops_t::kind_t::SUM) >= 0)
        SAFE(fill_dst(prb, dst_dt, dst_fp), WARN);

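    // Set up memory objects for binary post-op arguments, if any.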
    const bool binary_po_only_positive_vals = is_norm_alg(prb->alg);
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(const_pd, binary_po_args, binary_po_dt,
                 binary_po_fp, binary_po_only_positive_vals),
            WARN);

    args_t args, ref_args;

    args.set(DNNL_ARG_SRC, src_dt);
    args.set(DNNL_ARG_DST, dst_dt);
    args.set(binary_po_args, binary_po_dt);

    SAFE(execute_and_wait(prim, args, res), WARN);

    if (is_bench_mode(CORR)) {
        ref_args.set(DNNL_ARG_SRC, src_fp);
        ref_args.set(DNNL_ARG_DST, dst_fp);
        ref_args.set(binary_po_args, binary_po_fp);

        check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace reduction