1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <math.h> |
18 | #include <random> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "utils/parallel.hpp" |
25 | |
26 | #include "dnnl_common.hpp" |
27 | #include "dnnl_memory.hpp" |
28 | |
29 | #include "prelu/prelu.hpp" |
30 | |
31 | namespace prelu { |
32 | |
// Fills the f32 reference memory `mem_fp` with deterministic pseudo-random
// values appropriate for the given argument `kind` (SRC, WEI or DST), then
// reorders the result into the library-layout memory `mem_dt`.
// The filling is chunked so the same values are produced regardless of the
// number of threads. Returns OK on success; reorder failures propagate FAIL.
int fill_data(data_kind_t kind, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
    const auto nelems = mem_fp.nelems();
    if (nelems == 0) return OK;

    // Do fixed partitioning to have same filling for any number of threads.
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);

    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // Note 1: we use a different seed for each chunk to avoid
        // repeating patterns. We could use discard(idx_start) too but
        // we avoid it for two reasons:
        //   a. it has a complexity in O(idx_start).
        //   b. igen below might require more than 1 sample
        //   per idx, so the we cannot deterministically compute the
        //   number of states we need to discard
        // Note 2: We also advance the state to avoid having only
        // small values as first chunk input. The +1 is necessary to
        // avoid generating zeros in first chunk.
        // Note 3: we multiply by kind + 1 to have different values in
        // src/dst and diff_dst. The +1 is to avoid 0 again.
        std::minstd_rand msr((idx_start + 1) * (kind + 1));
        msr.discard(1);
        std::uniform_int_distribution<> igen_02(0, 2), igen_05(0, 5),
                igen_06(0, 6);
        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            float value = 0;
            if (is_integral_dt(mem_dt.dt())) {
                // Integral destination types get small exactly
                // representable integers.
                value = igen_05(msr);
            } else {
                // TODO: amount of negative values should depend on number of points
                // to reduce as summation becomes inaccurate.
                switch (kind) {
                    case SRC: value = igen_02(msr); break;
                    case WEI:
                        value = (64 >> igen_06(msr)) / 8.f; // pow2 [0.125f, 8f]
                        break;
                    case DST: value = igen_02(msr) / 16.f; break;
                    default: assert(!"unexpected" ); break;
                }
            }
            // u8 cannot represent negatives; otherwise flip the sign of
            // ~10% of the points to exercise the negative-slope path.
            float sign = mem_dt.dt() == dnnl_u8
                    ? 1.f
                    : flip_coin(idx, 0.1f) ? -1.f : 1.f;
            // Keep the value exactly representable in the target data type
            // so the reference and the library see identical inputs.
            value = round_to_nearest_representable(mem_dt.dt(), sign * value);
            mem_fp.set_elem(idx, value);
        }
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}
88 | |
89 | int setup_prelu_po(const_dnnl_primitive_desc_t pd, std::vector<int> &args, |
90 | std::vector<dnn_mem_t> &ref_mem, std::vector<dnn_mem_t> &prim_mem) { |
91 | const auto &dst_md = query_md(pd, DNNL_ARG_DST); |
92 | auto const_attr_po = query_post_ops(pd); |
93 | const auto po_len = dnnl_post_ops_len(const_attr_po); |
94 | for (int idx = 0; idx < po_len; ++idx) { |
95 | const auto kind = dnnl_post_ops_get_kind(const_attr_po, idx); |
96 | if (kind != dnnl_prelu) continue; |
97 | |
98 | const auto ndims = query_md_ndims(dst_md); |
99 | int mask = 0; |
100 | dnnl_dims_t dims = {0}; |
101 | dnnl_post_ops_get_params_prelu(const_attr_po, idx, &mask); |
102 | |
103 | // Deduce prelu weights dims based on input policy. |
104 | for (int d = 0; d < ndims; ++d) { |
105 | dims[d] = (mask & (1 << d)) ? query_md_dims(dst_md)[d] : 1; |
106 | } |
107 | |
108 | // Following call can not be executed if po_md has runtime dimension due |
109 | // to undefined size. |
110 | ref_mem.emplace_back(ndims, dims, dnnl_f32, tag::abx, get_cpu_engine()); |
111 | prim_mem.emplace_back( |
112 | ndims, dims, dnnl_f32, tag::axb, get_test_engine()); |
113 | args.push_back(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_WEIGHTS); |
114 | fill_data(WEI, prim_mem.back(), ref_mem.back()); |
115 | } |
116 | return OK; |
117 | } |
118 | |
119 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
120 | const prb_t *prb = init_pd_args.prb; |
121 | |
122 | const auto &src_dims = prb->vdims[0]; |
123 | const auto &weight_dims = prb->vdims[1]; |
124 | |
125 | auto src_d = dnn_mem_t::init_md( |
126 | prb->ndims, src_dims.data(), prb->sdt[0], prb->stag[0]); |
127 | auto weights_d = dnn_mem_t::init_md( |
128 | prb->ndims, weight_dims.data(), prb->sdt[1], prb->stag[1]); |
129 | |
130 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
131 | create_dnnl_attr(prb->attr, attr_args_t())); |
132 | |
133 | if (prb->dir & FLAG_FWD) { |
134 | auto dst_d = dnn_mem_t::init_md( |
135 | prb->ndims, src_dims.data(), prb->sdt[0], tag::any); |
136 | |
137 | auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference |
138 | : dnnl_forward_training; |
139 | DNN_SAFE_STATUS(dnnl_prelu_forward_primitive_desc_create( |
140 | &init_pd_args.pd, init_pd_args.engine, prop, src_d, weights_d, |
141 | dst_d, dnnl_attr)); |
142 | } else { |
143 | auto diff_src_d = dnn_mem_t::init_md( |
144 | prb->ndims, src_dims.data(), prb->sdt[0], tag::any); |
145 | auto diff_weights_d = dnn_mem_t::init_md( |
146 | prb->ndims, weight_dims.data(), prb->sdt[1], tag::any); |
147 | auto diff_dst_d = dnn_mem_t::init_md( |
148 | prb->ndims, src_dims.data(), prb->sdt[0], tag::any); |
149 | |
150 | DNN_SAFE_STATUS(dnnl_prelu_backward_primitive_desc_create( |
151 | &init_pd_args.pd, init_pd_args.engine, src_d, weights_d, |
152 | diff_src_d, diff_weights_d, diff_dst_d, init_pd_args.hint, |
153 | dnnl_attr)); |
154 | } |
155 | |
156 | return dnnl_success; |
157 | } |
158 | |
// Marks the problem as SKIPPED in `res` when it relies on features the
// current build/engine does not implement: unsupported data types or a sum
// post-op in the attributes.
void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type(prb->sdt, FWD_D, res);
    skip_unimplemented_sum_po(prb->attr, res);
}
163 | |
// No invalid problem combinations exist for prelu; intentionally a no-op.
void skip_invalid_prb(const prb_t *prb, res_t *res) {}
165 | |
166 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
167 | const args_t &ref_args) { |
168 | const auto trh_dt = kind == WEI ? prb->sdt[1] : prb->sdt[0]; |
169 | cmp.set_threshold(2 * epsilon_dt(trh_dt)); |
170 | |
171 | // Weights are very sparse, no sense to test for trust, otherwise filling |
172 | // is specific to cover half non-zeros only. |
173 | const float zero_trust_percent = kind == WEI ? 99.f : 50.f; |
174 | cmp.set_zero_trust_percent(zero_trust_percent); |
175 | } |
176 | |
177 | int doit(const prb_t *prb, res_t *res) { |
178 | if (bench_mode == LIST) return res->state = LISTED, OK; |
179 | |
180 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim; |
181 | SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN); |
182 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
183 | |
184 | auto const_pd = query_pd(prim); |
185 | |
186 | const auto &data_md = query_md(const_pd, DNNL_ARG_SRC); |
187 | const auto &weight_md = query_md(const_pd, DNNL_ARG_WEIGHTS); |
188 | const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD); |
189 | const auto &test_engine = get_test_engine(); |
190 | const auto &ref_engine = get_cpu_engine(); |
191 | |
192 | dnn_mem_t src_fp(data_md, dnnl_f32, tag::abx, ref_engine); |
193 | dnn_mem_t weights_fp(weight_md, dnnl_f32, tag::abx, ref_engine); |
194 | |
195 | dnn_mem_t src_dt(data_md, test_engine); |
196 | dnn_mem_t weights_dt(weight_md, test_engine); |
197 | dnn_mem_t scratchpad_dt(scratchpad_md, test_engine); |
198 | |
199 | SAFE(fill_data(SRC, src_dt, src_fp), WARN); |
200 | SAFE(fill_data(WEI, weights_dt, weights_fp), WARN); |
201 | |
202 | args_t args, ref_args; |
203 | |
204 | args.set(DNNL_ARG_SRC, src_dt); |
205 | args.set(DNNL_ARG_WEIGHTS, weights_dt); |
206 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
207 | |
208 | dnn_mem_t dst_dt, d_src_fp, d_src_dt, d_dst_fp, d_dst_dt, d_weights_fp, |
209 | d_weights_dt; |
210 | |
211 | if (prb->dir & FLAG_FWD) { |
212 | dnn_mem_t dst_fp(data_md, dnnl_f32, tag::abx, ref_engine); |
213 | dst_dt = dnn_mem_t(data_md, test_engine); |
214 | |
215 | args.set(DNNL_ARG_DST, dst_dt); |
216 | |
217 | SAFE(execute_and_wait(prim, args, res), WARN); |
218 | |
219 | if (is_bench_mode(CORR)) { |
220 | ref_args.set(DNNL_ARG_SRC, src_fp); |
221 | ref_args.set(DNNL_ARG_WEIGHTS, weights_fp); |
222 | ref_args.set(DNNL_ARG_DST, dst_fp); |
223 | |
224 | check_correctness(prb, {DST}, args, ref_args, setup_cmp, res); |
225 | } |
226 | } else { |
227 | const auto &d_data_md = query_md(const_pd, DNNL_ARG_DIFF_DST); |
228 | const auto &d_weights_md = query_md(const_pd, DNNL_ARG_DIFF_WEIGHTS); |
229 | |
230 | dnn_mem_t d_src_fp(d_data_md, dnnl_f32, tag::abx, ref_engine); |
231 | dnn_mem_t d_weights_fp(d_weights_md, dnnl_f32, tag::abx, ref_engine); |
232 | dnn_mem_t d_dst_fp(d_data_md, dnnl_f32, tag::abx, ref_engine); |
233 | |
234 | d_src_dt = dnn_mem_t(d_data_md, test_engine); |
235 | d_weights_dt = dnn_mem_t(d_weights_md, test_engine); |
236 | d_dst_dt = dnn_mem_t(d_data_md, test_engine); |
237 | |
238 | SAFE(fill_data(DST, d_dst_dt, d_dst_fp), WARN); |
239 | |
240 | args.set(DNNL_ARG_DIFF_DST, d_dst_dt); |
241 | args.set(DNNL_ARG_DIFF_SRC, d_src_dt); |
242 | args.set(DNNL_ARG_DIFF_WEIGHTS, d_weights_dt); |
243 | |
244 | SAFE(execute_and_wait(prim, args, res), WARN); |
245 | |
246 | if (is_bench_mode(CORR)) { |
247 | ref_args.set(DNNL_ARG_SRC, src_fp); |
248 | ref_args.set(DNNL_ARG_WEIGHTS, weights_fp); |
249 | ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp); |
250 | ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp); |
251 | ref_args.set(DNNL_ARG_DIFF_WEIGHTS, d_weights_fp); |
252 | |
253 | check_correctness(prb, {SRC, WEI}, args, ref_args, setup_cmp, res); |
254 | } |
255 | } |
256 | |
257 | return measure_perf(prb->ctx_exe, res, prim, args); |
258 | } |
259 | |
260 | } // namespace prelu |
261 | |