1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <math.h>
18#include <random>
19#include <stdio.h>
20#include <stdlib.h>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "utils/parallel.hpp"
25
26#include "dnnl_common.hpp"
27#include "dnnl_memory.hpp"
28
29#include "prelu/prelu.hpp"
30
31namespace prelu {
32
33int fill_data(data_kind_t kind, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
34 const auto nelems = mem_fp.nelems();
35 if (nelems == 0) return OK;
36
37 // Do fixed partitioning to have same filling for any number of threads.
38 const int64_t n_chunks = 16;
39 const int64_t chunk_size = div_up(nelems, n_chunks);
40
41 benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
42 int64_t idx_start = idx_chunk * chunk_size;
43 int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
44 // Note 1: we use a different seed for each chunk to avoid
45 // repeating patterns. We could use discard(idx_start) too but
46 // we avoid it for two reasons:
47 // a. it has a complexity in O(idx_start).
48 // b. igen below might require more than 1 sample
49 // per idx, so the we cannot deterministically compute the
50 // number of states we need to discard
51 // Note 2: We also advance the state to avoid having only
52 // small values as first chunk input. The +1 is necessary to
53 // avoid generating zeros in first chunk.
54 // Note 3: we multiply by kind + 1 to have different values in
55 // src/dst and diff_dst. The +1 is to avoid 0 again.
56 std::minstd_rand msr((idx_start + 1) * (kind + 1));
57 msr.discard(1);
58 std::uniform_int_distribution<> igen_02(0, 2), igen_05(0, 5),
59 igen_06(0, 6);
60 for (int64_t idx = idx_start; idx < idx_end; ++idx) {
61 float value = 0;
62 if (is_integral_dt(mem_dt.dt())) {
63 value = igen_05(msr);
64 } else {
65 // TODO: amount of negative values should depend on number of points
66 // to reduce as summation becomes inaccurate.
67 switch (kind) {
68 case SRC: value = igen_02(msr); break;
69 case WEI:
70 value = (64 >> igen_06(msr)) / 8.f; // pow2 [0.125f, 8f]
71 break;
72 case DST: value = igen_02(msr) / 16.f; break;
73 default: assert(!"unexpected"); break;
74 }
75 }
76 float sign = mem_dt.dt() == dnnl_u8
77 ? 1.f
78 : flip_coin(idx, 0.1f) ? -1.f : 1.f;
79 value = round_to_nearest_representable(mem_dt.dt(), sign * value);
80 mem_fp.set_elem(idx, value);
81 }
82 });
83
84 SAFE(mem_dt.reorder(mem_fp), WARN);
85
86 return OK;
87}
88
89int setup_prelu_po(const_dnnl_primitive_desc_t pd, std::vector<int> &args,
90 std::vector<dnn_mem_t> &ref_mem, std::vector<dnn_mem_t> &prim_mem) {
91 const auto &dst_md = query_md(pd, DNNL_ARG_DST);
92 auto const_attr_po = query_post_ops(pd);
93 const auto po_len = dnnl_post_ops_len(const_attr_po);
94 for (int idx = 0; idx < po_len; ++idx) {
95 const auto kind = dnnl_post_ops_get_kind(const_attr_po, idx);
96 if (kind != dnnl_prelu) continue;
97
98 const auto ndims = query_md_ndims(dst_md);
99 int mask = 0;
100 dnnl_dims_t dims = {0};
101 dnnl_post_ops_get_params_prelu(const_attr_po, idx, &mask);
102
103 // Deduce prelu weights dims based on input policy.
104 for (int d = 0; d < ndims; ++d) {
105 dims[d] = (mask & (1 << d)) ? query_md_dims(dst_md)[d] : 1;
106 }
107
108 // Following call can not be executed if po_md has runtime dimension due
109 // to undefined size.
110 ref_mem.emplace_back(ndims, dims, dnnl_f32, tag::abx, get_cpu_engine());
111 prim_mem.emplace_back(
112 ndims, dims, dnnl_f32, tag::axb, get_test_engine());
113 args.push_back(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_WEIGHTS);
114 fill_data(WEI, prim_mem.back(), ref_mem.back());
115 }
116 return OK;
117}
118
119dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
120 const prb_t *prb = init_pd_args.prb;
121
122 const auto &src_dims = prb->vdims[0];
123 const auto &weight_dims = prb->vdims[1];
124
125 auto src_d = dnn_mem_t::init_md(
126 prb->ndims, src_dims.data(), prb->sdt[0], prb->stag[0]);
127 auto weights_d = dnn_mem_t::init_md(
128 prb->ndims, weight_dims.data(), prb->sdt[1], prb->stag[1]);
129
130 auto dnnl_attr = make_benchdnn_dnnl_wrapper(
131 create_dnnl_attr(prb->attr, attr_args_t()));
132
133 if (prb->dir & FLAG_FWD) {
134 auto dst_d = dnn_mem_t::init_md(
135 prb->ndims, src_dims.data(), prb->sdt[0], tag::any);
136
137 auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference
138 : dnnl_forward_training;
139 DNN_SAFE_STATUS(dnnl_prelu_forward_primitive_desc_create(
140 &init_pd_args.pd, init_pd_args.engine, prop, src_d, weights_d,
141 dst_d, dnnl_attr));
142 } else {
143 auto diff_src_d = dnn_mem_t::init_md(
144 prb->ndims, src_dims.data(), prb->sdt[0], tag::any);
145 auto diff_weights_d = dnn_mem_t::init_md(
146 prb->ndims, weight_dims.data(), prb->sdt[1], tag::any);
147 auto diff_dst_d = dnn_mem_t::init_md(
148 prb->ndims, src_dims.data(), prb->sdt[0], tag::any);
149
150 DNN_SAFE_STATUS(dnnl_prelu_backward_primitive_desc_create(
151 &init_pd_args.pd, init_pd_args.engine, src_d, weights_d,
152 diff_src_d, diff_weights_d, diff_dst_d, init_pd_args.hint,
153 dnnl_attr));
154 }
155
156 return dnnl_success;
157}
158
159void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
160 skip_unimplemented_data_type(prb->sdt, FWD_D, res);
161 skip_unimplemented_sum_po(prb->attr, res);
162}
163
164void skip_invalid_prb(const prb_t *prb, res_t *res) {}
165
166void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
167 const args_t &ref_args) {
168 const auto trh_dt = kind == WEI ? prb->sdt[1] : prb->sdt[0];
169 cmp.set_threshold(2 * epsilon_dt(trh_dt));
170
171 // Weights are very sparse, no sense to test for trust, otherwise filling
172 // is specific to cover half non-zeros only.
173 const float zero_trust_percent = kind == WEI ? 99.f : 50.f;
174 cmp.set_zero_trust_percent(zero_trust_percent);
175}
176
177int doit(const prb_t *prb, res_t *res) {
178 if (bench_mode == LIST) return res->state = LISTED, OK;
179
180 benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
181 SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
182 if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
183
184 auto const_pd = query_pd(prim);
185
186 const auto &data_md = query_md(const_pd, DNNL_ARG_SRC);
187 const auto &weight_md = query_md(const_pd, DNNL_ARG_WEIGHTS);
188 const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);
189 const auto &test_engine = get_test_engine();
190 const auto &ref_engine = get_cpu_engine();
191
192 dnn_mem_t src_fp(data_md, dnnl_f32, tag::abx, ref_engine);
193 dnn_mem_t weights_fp(weight_md, dnnl_f32, tag::abx, ref_engine);
194
195 dnn_mem_t src_dt(data_md, test_engine);
196 dnn_mem_t weights_dt(weight_md, test_engine);
197 dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
198
199 SAFE(fill_data(SRC, src_dt, src_fp), WARN);
200 SAFE(fill_data(WEI, weights_dt, weights_fp), WARN);
201
202 args_t args, ref_args;
203
204 args.set(DNNL_ARG_SRC, src_dt);
205 args.set(DNNL_ARG_WEIGHTS, weights_dt);
206 args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
207
208 dnn_mem_t dst_dt, d_src_fp, d_src_dt, d_dst_fp, d_dst_dt, d_weights_fp,
209 d_weights_dt;
210
211 if (prb->dir & FLAG_FWD) {
212 dnn_mem_t dst_fp(data_md, dnnl_f32, tag::abx, ref_engine);
213 dst_dt = dnn_mem_t(data_md, test_engine);
214
215 args.set(DNNL_ARG_DST, dst_dt);
216
217 SAFE(execute_and_wait(prim, args, res), WARN);
218
219 if (is_bench_mode(CORR)) {
220 ref_args.set(DNNL_ARG_SRC, src_fp);
221 ref_args.set(DNNL_ARG_WEIGHTS, weights_fp);
222 ref_args.set(DNNL_ARG_DST, dst_fp);
223
224 check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
225 }
226 } else {
227 const auto &d_data_md = query_md(const_pd, DNNL_ARG_DIFF_DST);
228 const auto &d_weights_md = query_md(const_pd, DNNL_ARG_DIFF_WEIGHTS);
229
230 dnn_mem_t d_src_fp(d_data_md, dnnl_f32, tag::abx, ref_engine);
231 dnn_mem_t d_weights_fp(d_weights_md, dnnl_f32, tag::abx, ref_engine);
232 dnn_mem_t d_dst_fp(d_data_md, dnnl_f32, tag::abx, ref_engine);
233
234 d_src_dt = dnn_mem_t(d_data_md, test_engine);
235 d_weights_dt = dnn_mem_t(d_weights_md, test_engine);
236 d_dst_dt = dnn_mem_t(d_data_md, test_engine);
237
238 SAFE(fill_data(DST, d_dst_dt, d_dst_fp), WARN);
239
240 args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
241 args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
242 args.set(DNNL_ARG_DIFF_WEIGHTS, d_weights_dt);
243
244 SAFE(execute_and_wait(prim, args, res), WARN);
245
246 if (is_bench_mode(CORR)) {
247 ref_args.set(DNNL_ARG_SRC, src_fp);
248 ref_args.set(DNNL_ARG_WEIGHTS, weights_fp);
249 ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp);
250 ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp);
251 ref_args.set(DNNL_ARG_DIFF_WEIGHTS, d_weights_fp);
252
253 check_correctness(prb, {SRC, WEI}, args, ref_args, setup_cmp, res);
254 }
255 }
256
257 return measure_perf(prb->ctx_exe, res, prim, args);
258}
259
260} // namespace prelu
261