1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <random> |
18 | |
19 | #include <float.h> |
20 | #include <math.h> |
21 | #include <stdio.h> |
22 | #include <stdlib.h> |
23 | |
24 | #include "oneapi/dnnl/dnnl.h" |
25 | |
26 | #include "utils/parallel.hpp" |
27 | |
28 | #include "dnnl_common.hpp" |
29 | #include "dnnl_memory.hpp" |
30 | |
31 | #include "softmax/softmax.hpp" |
32 | |
33 | namespace softmax { |
34 | |
35 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
36 | const prb_t *prb = init_pd_args.prb; |
37 | |
38 | auto dst_d = dnn_mem_t::init_md( |
39 | prb->ndims, prb->dims.data(), prb->ddt, prb->dtag); |
40 | |
41 | dnnl_alg_kind_t alg_kind = dnnl_softmax_accurate; |
42 | if (prb->alg == LOGSOFTMAX) alg_kind = dnnl_softmax_log; |
43 | |
44 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
45 | create_dnnl_attr(prb->attr, attr_args_t())); |
46 | |
47 | if (prb->dir & FLAG_FWD) { |
48 | auto src_d = dnn_mem_t::init_md( |
49 | prb->ndims, prb->dims.data(), prb->sdt, prb->stag); |
50 | |
51 | auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference |
52 | : dnnl_forward_training; |
53 | |
54 | DNN_SAFE_STATUS(dnnl_softmax_forward_primitive_desc_create( |
55 | &init_pd_args.pd, init_pd_args.engine, prop, alg_kind, src_d, |
56 | dst_d, prb->axis, dnnl_attr)); |
57 | } else { |
        // Re-create dst_md with the source tag if the dst tag was not
        // specified, imitating the default behavior.
60 | if (prb->dtag == tag::any) { |
61 | dst_d = dnn_mem_t::init_md( |
62 | prb->ndims, prb->dims.data(), prb->ddt, prb->stag); |
63 | } |
64 | |
65 | auto diff_src_d = dnn_mem_t::init_md( |
66 | prb->ndims, prb->dims.data(), prb->sdt, tag::any); |
67 | auto diff_dst_d = dnn_mem_t::init_md( |
68 | prb->ndims, prb->dims.data(), prb->ddt, tag::any); |
69 | |
70 | DNN_SAFE_STATUS(dnnl_softmax_backward_primitive_desc_create( |
71 | &init_pd_args.pd, init_pd_args.engine, alg_kind, diff_src_d, |
72 | diff_dst_d, dst_d, prb->axis, init_pd_args.hint, dnnl_attr)); |
73 | } |
74 | |
75 | return dnnl_success; |
76 | } |
77 | |
78 | int fill_data_fwd(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) { |
79 | int64_t outer_size = 0, inner_size = 0, axis_size = 0; |
80 | get_sizes(prb, outer_size, inner_size, axis_size); |
81 | |
    // Fill data so that two modes are tested: max_val < 0 and max_val >= 0.
    // Test max_val < 0 by using only negative numbers to verify correct
    // max_val subtraction, in particular that the library uses the signed
    // value rather than its absolute value.
    // Test max_val >= 0 by exceeding the `exp_ovfl_arg` value to verify the
    // answer does not contain +infinity (NaN).
    // Distribute several top-1 values to check that softmax works correctly.
    // Also use a few more top-2 values so they contribute to the final exp
    // sum as well, and fill many more values with top-3 to check the math is
    // correct over the whole input.
    // Filling data this way prevents cancellation errors for LOGSOFTMAX:
    // log(sum(exp(x_j))) won't be close to zero, as it would be with a
    // single top-1 value.
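    // For illustration only (actual values depend on the RNG draws): with
    // axis_size = 16, n_top = {1, 3, 5}, and top_val = {90, 89, 88}, one
    // axis may be filled as
    //   [INT_MIN, ..., 90, 89, 89, 89, 88, 88, 88, 88, 88, INT_MIN, ...],
    // where the INT_MIN entries make exp() underflow to zero and thus do not
    // contribute to the sum.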
93 | |
    // Do fixed partitioning to have the same filling for any number of
    // threads.
95 | const int64_t n_chunks = 16; |
96 | const int64_t chunk_size = div_up(outer_size, n_chunks); |
97 | |
98 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) { |
99 | int64_t idx_start = idx_chunk * chunk_size; |
100 | int64_t idx_end = MIN2(idx_start + chunk_size, outer_size); |
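        // Seed with idx_start + 1 so each chunk gets a distinct stream (+1
        // avoids seeding with 0); discard one draw to decorrelate the first
        // outputs. See the analogous note in fill_scales().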
101 | std::minstd_rand msr(idx_start + 1); |
102 | msr.discard(1); |
103 | std::vector<std::uniform_int_distribution<>> igen_top_fp { |
104 | std::uniform_int_distribution<>(1, 2), |
105 | std::uniform_int_distribution<>(2, 5), |
106 | std::uniform_int_distribution<>(5, 8)}; |
107 | std::vector<std::uniform_int_distribution<>> igen_top_int8 { |
108 | std::uniform_int_distribution<>(1, 1), |
109 | std::uniform_int_distribution<>(1, 1), |
110 | std::uniform_int_distribution<>(0, 4)}; |
111 | std::vector<std::uniform_int_distribution<>> igen_top |
112 | = dnnl_data_type_size(prb->ddt) == 1 ? igen_top_int8 |
113 | : igen_top_fp; |
114 | const int sign = (idx_chunk % 2 != 0 && prb->sdt != dnnl_u8) ? -1 : 1; |
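        // With sign == 1, 88 sits just below ln(FLT_MAX) ~= 88.72, so
        // expf(89) and expf(90) overflow f32; the result contains inf/NaN
        // unless the implementation subtracts the max first.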
115 | const int exp_ovfl_arg = 88 * sign; |
116 | std::vector<int> top_val { |
117 | exp_ovfl_arg + 2, exp_ovfl_arg + 1, exp_ovfl_arg}; |
118 | |
119 | for_(int64_t idx = idx_start; idx < idx_end; ++idx) |
120 | for (int64_t in = 0; in < inner_size; in++) { |
121 | std::vector<int64_t> n_top { |
122 | igen_top[0](msr), igen_top[1](msr), igen_top[2](msr)}; |
123 | int i = 2; |
124 | int64_t n_sum = n_top[0] + n_top[1] + n_top[2]; |
125 | // Adjust number of top elements to fit axis_size if needed |
126 | while (n_sum > axis_size) { |
127 | n_sum -= n_top[i]; |
128 | n_top[i] -= std::min(n_top[i], n_sum + n_top[i] - axis_size); |
129 | n_sum += n_top[i]; |
130 | i--; |
131 | } |
            // If the number of top elements is less than axis_size, pick a
            // random index to start the dense filling from.
134 | std::uniform_int_distribution<> igen_as_idx(0, axis_size - n_sum); |
135 | msr.discard(2); |
136 | int64_t axis_idx_start = igen_as_idx(msr); |
137 | |
138 | i = 0; |
139 | for (int64_t as = 0; as < axis_size; as++) { |
140 | auto offset = inner_size * (idx * axis_size + as) + in; |
141 | float value = INT_MIN; |
142 | if (as >= axis_idx_start && as < axis_idx_start + n_sum) { |
143 | value = top_val[i]; |
144 | n_top[i]--; |
145 | if (n_top[i] == 0) i++; |
146 | } |
147 | mem_fp.set_elem(offset, |
148 | round_to_nearest_representable(mem_dt.dt(), value)); |
149 | } |
150 | } |
151 | }); |
152 | |
153 | SAFE(mem_dt.reorder(mem_fp), WARN); |
154 | |
155 | return OK; |
156 | } |
157 | |
158 | int fill_data_bwd( |
159 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, int seed) { |
160 | const auto nelems = mem_fp.nelems(); |
161 | const int range = 128; |
162 | |
    // To avoid any cancellation errors it's better to have d_dst and dst of
    // different signs (refer to ref computations):
    // softmax := dst * (d_dst - SUM(d_dst * dst)); keep +d_dst and -dst.
    // logsoftmax := d_dst - exp(dst) * SUM(d_dst); keep -d_dst and +dst.
    // The seed decides the sign.
168 | const float sign = seed % 2 == 0 ? 1.f : -1.f; |
169 | benchdnn_parallel_nd(nelems, [&](int64_t i) { |
170 | const float gen = ((11 * i) + 37 + 19 * seed) % range; |
171 | const float value = sign * gen / range; |
172 | mem_fp.set_elem(i, value); |
173 | }); |
174 | |
175 | SAFE(mem_dt.reorder(mem_fp), WARN); |
176 | |
177 | return OK; |
178 | } |
179 | |
180 | int fill_scales( |
181 | const attr_t &attr, int arg, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) { |
182 | const auto nelems = mem_fp.nelems(); |
183 | if (nelems == 0) return OK; |
184 | |
185 | assert(mem_dt.nelems() == mem_fp.nelems()); |
186 | |
187 | const auto &scales = attr.scales.get(arg); |
188 | |
    /* Do fixed partitioning to have the same filling for any number of
     * threads */
190 | const int64_t n_chunks = 16; |
191 | const int64_t chunk_size = div_up(nelems, n_chunks); |
192 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) { |
193 | int64_t idx_start = idx_chunk * chunk_size; |
194 | int64_t idx_end = MIN2(idx_start + chunk_size, nelems); |
        // Note: we use a different seed for each chunk to avoid repeating
        // patterns. We could use discard(idx_start) instead, but its
        // complexity is O(idx_start). We also add 1 to avoid seeding with 0.
199 | std::minstd_rand int_seed(idx_start + 1); |
200 | int_seed.discard(1); |
201 | |
202 | std::uniform_int_distribution<> gen(-5, 5); |
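        // Exponents in [-5, 5] yield scales that are exact powers of 2 in
        // [1/32, 32], so scaling itself introduces no rounding error.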
203 | |
204 | for (int64_t idx = idx_start; idx < idx_end; ++idx) { |
205 | int pow2 = gen(int_seed); |
206 | int pow2_shift = 1 << std::abs(pow2); |
207 | const float gen_val = pow2 < 0 ? (1.f / pow2_shift) : pow2_shift; |
208 | const float fixed_val = scales.scale; |
209 | const float val = nelems == 1 ? fixed_val : gen_val; |
210 | mem_fp.set_elem(idx, val); |
211 | } |
212 | }); |
213 | |
214 | SAFE(mem_dt.reorder(mem_fp), WARN); |
215 | |
216 | return OK; |
217 | } |
218 | |
219 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
220 | skip_unimplemented_data_type({prb->sdt, prb->ddt}, prb->dir, res); |
221 | skip_unimplemented_sum_po(prb->attr, res); |
222 | } |
223 | |
224 | void skip_invalid_prb(const prb_t *prb, res_t *res) { |
225 | // See `skip_invalid_inplace` for details. |
226 | if (prb->inplace) { |
227 | skip_invalid_inplace(res, prb->sdt, prb->ddt, prb->stag, prb->dtag); |
228 | if (res->state == SKIPPED) return; |
229 | } |
230 | } |
231 | |
232 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
233 | const args_t &ref_args) { |
234 | const auto trh_dt = (prb->dir & FLAG_FWD) ? prb->ddt : prb->sdt; |
235 | const float trh_coeff_log = prb->alg == LOGSOFTMAX ? 5 : 1; |
236 | const float trh_coeff_f32 = trh_dt == dnnl_f32 ? 10.f : 1.f; |
237 | const float trh_coeff_bwd = (prb->dir & FLAG_FWD) ? 1.f : 4.f; |
238 | const float trh = trh_coeff_log * trh_coeff_bwd * trh_coeff_f32 |
239 | * epsilon_dt(trh_dt); |
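    // E.g. logsoftmax bwd with f32 src: 5 * 4 * 10 * 2^(-23) ~= 2.4e-5.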
240 | cmp.set_threshold(trh); |
241 | |
242 | const int64_t axis_size = prb->dims[prb->axis]; |
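    // Expected zeros in dst: an int8 dst rounds all but the top value to
    // zero; for float dst, filling provides at least 8 finite values per
    // axis (see fill_data_fwd), so at most axis_size - 8 outputs are zero.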
243 | const int64_t n_zeros = (prb->ddt == dnnl_s8 || prb->ddt == dnnl_u8) |
244 | ? (axis_size - 1) |
245 | : MAX2(0, axis_size - 8); |
246 | float zero_trust_percent = 100.f * n_zeros / axis_size; |
    // Note:
    // * Logsoftmax over an axis of size `1` does not make sense.
    // * Neither does logsoftmax with u8 dst.
250 | if (prb->alg == LOGSOFTMAX && (axis_size == 1 || prb->ddt == dnnl_u8)) |
251 | zero_trust_percent = 100.f; |
252 | if (prb->dir & FLAG_BWD) zero_trust_percent = 30.f; |
253 | cmp.set_zero_trust_percent(zero_trust_percent); |
254 | |
255 | const auto softmax_add_check |
256 | = [&](const compare::compare_t::driver_check_func_args_t &args) { |
257 | // SSE4.1 and OpenCL rdiff tolerance is too high for |
258 | // certain scenarios. |
259 | return args.diff < epsilon_dt(args.dt); |
260 | }; |
261 | cmp.set_driver_check_function(softmax_add_check); |
262 | } |
263 | |
264 | int doit(const prb_t *prb, res_t *res) { |
265 | if (bench_mode == LIST) return res->state = LISTED, OK; |
266 | |
267 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim; |
268 | SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN); |
269 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
270 | |
271 | auto const_pd = query_pd(prim); |
272 | |
273 | const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD); |
274 | const auto &test_engine = get_test_engine(); |
275 | const auto &ref_engine = get_cpu_engine(); |
276 | |
277 | dnn_mem_t src_dt, placeholder_dst_dt; |
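    // For in-place forward execution DST aliases the SRC buffer.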
278 | dnn_mem_t &dst_dt = prb->inplace && (prb->dir & FLAG_FWD) |
279 | ? src_dt |
280 | : placeholder_dst_dt; |
281 | dnn_mem_t scratchpad_dt(scratchpad_md, test_engine); |
282 | |
283 | dnn_mem_t d_dst_dt, placeholder_d_src_dt; |
284 | dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt; |
285 | |
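    // Softmax supports only per-tensor (common) scales, hence a single f32
    // element each for the SRC and DST arguments.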
286 | const dnnl_dims_t scale_dims = {1}; |
287 | auto scales_md = dnn_mem_t::init_md(1, scale_dims, dnnl_f32, tag::abx); |
288 | dnn_mem_t src_scales_dt(scales_md, test_engine); |
289 | dnn_mem_t dst_scales_dt(scales_md, test_engine); |
290 | |
291 | args_t args, ref_args; |
292 | |
293 | if (prb->dir & FLAG_FWD) { |
294 | const auto &src_md = query_md(const_pd, DNNL_ARG_SRC); |
295 | const auto &dst_md = query_md(const_pd, DNNL_ARG_DST); |
296 | |
297 | src_dt = dnn_mem_t(src_md, test_engine); |
298 | if (!prb->inplace) { |
299 | placeholder_dst_dt = dnn_mem_t(dst_md, test_engine); |
300 | } |
301 | |
302 | dnn_mem_t src_fp(src_md, dnnl_f32, tag::abx, ref_engine); |
303 | dnn_mem_t &dst_fp = src_fp; // in-place reference |
304 | SAFE(fill_data_fwd(prb, src_dt, src_fp), WARN); |
305 | |
306 | dnn_mem_t src_scales_fp(scales_md, ref_engine); |
307 | dnn_mem_t dst_scales_fp(scales_md, ref_engine); |
308 | fill_scales(prb->attr, DNNL_ARG_SRC, src_scales_dt, src_scales_fp); |
309 | fill_scales(prb->attr, DNNL_ARG_DST, dst_scales_dt, dst_scales_fp); |
310 | |
311 | args.set(DNNL_ARG_SRC, src_dt); |
312 | args.set(DNNL_ARG_DST, dst_dt); |
313 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
314 | args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt); |
315 | args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt); |
316 | |
317 | SAFE(execute_and_wait(prim, args, res), WARN); |
318 | |
319 | if (is_bench_mode(CORR)) { |
320 | ref_args.set(DNNL_ARG_SRC, src_fp); |
321 | ref_args.set(DNNL_ARG_DST, dst_fp); |
322 | ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_fp); |
323 | ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_fp); |
324 | |
325 | check_correctness(prb, {DST}, args, ref_args, setup_cmp, res); |
326 | } |
327 | } else { |
328 | const auto &dst_md = query_md(const_pd, DNNL_ARG_DST); |
329 | const auto &d_dst_md = query_md(const_pd, DNNL_ARG_DIFF_DST); |
330 | const auto &d_src_md = query_md(const_pd, DNNL_ARG_DIFF_SRC); |
331 | |
332 | placeholder_dst_dt = dnn_mem_t(dst_md, test_engine); |
333 | d_dst_dt = dnn_mem_t(d_dst_md, test_engine); |
334 | if (!prb->inplace) { |
335 | placeholder_d_src_dt = dnn_mem_t(d_src_md, test_engine); |
336 | } |
337 | |
338 | dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, ref_engine); |
339 | dnn_mem_t d_dst_fp(d_dst_md, dnnl_f32, tag::abx, ref_engine); |
340 | dnn_mem_t &d_src_fp = d_dst_fp; // in-place reference |
341 | |
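        // Opposite seed parities give dst and d_dst opposite signs, as
        // required by the sign convention described in fill_data_bwd().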
        const bool neg_sign = prb->alg == SOFTMAX;
343 | SAFE(fill_data_bwd(prb, dst_dt, dst_fp, neg_sign), WARN); |
344 | SAFE(fill_data_bwd(prb, d_dst_dt, d_dst_fp, !neg_sign), WARN); |
345 | |
346 | args.set(DNNL_ARG_DST, dst_dt); |
347 | args.set(DNNL_ARG_DIFF_DST, d_dst_dt); |
348 | args.set(DNNL_ARG_DIFF_SRC, d_src_dt); |
349 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
350 | |
351 | SAFE(execute_and_wait(prim, args, res), WARN); |
352 | |
353 | if (is_bench_mode(CORR)) { |
354 | ref_args.set(DNNL_ARG_DST, dst_fp); |
355 | ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp); |
356 | ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp); |
357 | |
358 | check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res); |
359 | } |
360 | } |
361 | |
362 | return measure_perf(prb->ctx_exe, res, prim, args); |
363 | } |
364 | |
365 | } // namespace softmax |
366 | |