/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <random>

#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "oneapi/dnnl/dnnl.h"

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "softmax/softmax.hpp"

namespace softmax {

dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    auto dst_d = dnn_mem_t::init_md(
            prb->ndims, prb->dims.data(), prb->ddt, prb->dtag);

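    // Map the benchdnn-level algorithm to the library's softmax algorithm
    // kind: accurate softmax by default, log-softmax when requested.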
    dnnl_alg_kind_t alg_kind = dnnl_softmax_accurate;
    if (prb->alg == LOGSOFTMAX) alg_kind = dnnl_softmax_log;

    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args_t()));

    if (prb->dir & FLAG_FWD) {
        auto src_d = dnn_mem_t::init_md(
                prb->ndims, prb->dims.data(), prb->sdt, prb->stag);

        auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference
                                        : dnnl_forward_training;

        DNN_SAFE_STATUS(dnnl_softmax_forward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, prop, alg_kind, src_d,
                dst_d, prb->axis, dnnl_attr));
    } else {
        // If dst was not specified, re-create dst_d with the source tag to
        // imitate the library's default behavior.
        if (prb->dtag == tag::any) {
            dst_d = dnn_mem_t::init_md(
                    prb->ndims, prb->dims.data(), prb->ddt, prb->stag);
        }

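        // Use format tag `any` for the diff memory descriptors so the
        // implementation is free to pick its preferred formats.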
        auto diff_src_d = dnn_mem_t::init_md(
                prb->ndims, prb->dims.data(), prb->sdt, tag::any);
        auto diff_dst_d = dnn_mem_t::init_md(
                prb->ndims, prb->dims.data(), prb->ddt, tag::any);

        DNN_SAFE_STATUS(dnnl_softmax_backward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, alg_kind, diff_src_d,
                diff_dst_d, dst_d, prb->axis, init_pd_args.hint, dnnl_attr));
    }

    return dnnl_success;
}

int fill_data_fwd(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
    int64_t outer_size = 0, inner_size = 0, axis_size = 0;
    get_sizes(prb, outer_size, inner_size, axis_size);

    // Fill data so that both modes are tested: max_val < 0 and max_val >= 0.
    // Test max_val < 0 by using only negative numbers to verify that the
    // library subtracts the signed maximum rather than its absolute value.
    // Test max_val >= 0 by exceeding the `exp_ovfl_arg` value to verify that
    // the answer does not contain +infinity (NaN).
    // Distribute several top-1 values to check that softmax works correctly.
    // Also use a few more top-2 values so they contribute to the final exp
    // sum as well, and many more top-3 values to check that the math is
    // applied correctly to the whole input.
    // Filling data this way also prevents cancellation errors for LOGSOFTMAX,
    // since log(sum(exp(x_j))) is not close to zero, as it would be with a
    // single top-1 value.

    // Do fixed partitioning to have the same filling for any number of
    // threads.
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(outer_size, n_chunks);

    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, outer_size);
        std::minstd_rand msr(idx_start + 1);
        msr.discard(1);
        std::vector<std::uniform_int_distribution<>> igen_top_fp {
                std::uniform_int_distribution<>(1, 2),
                std::uniform_int_distribution<>(2, 5),
                std::uniform_int_distribution<>(5, 8)};
        std::vector<std::uniform_int_distribution<>> igen_top_int8 {
                std::uniform_int_distribution<>(1, 1),
                std::uniform_int_distribution<>(1, 1),
                std::uniform_int_distribution<>(0, 4)};
        std::vector<std::uniform_int_distribution<>> igen_top
                = dnnl_data_type_size(prb->ddt) == 1 ? igen_top_int8
                                                     : igen_top_fp;
        const int sign = (idx_chunk % 2 != 0 && prb->sdt != dnnl_u8) ? -1 : 1;
        const int exp_ovfl_arg = 88 * sign;
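        // For the positive case, 88 is close to ln(FLT_MAX) (~88.7): exp() of
        // the largest top values (exp_ovfl_arg + 1 and + 2) would overflow
        // f32 unless the maximum is subtracted first.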
        std::vector<int> top_val {
                exp_ovfl_arg + 2, exp_ovfl_arg + 1, exp_ovfl_arg};

        for_(int64_t idx = idx_start; idx < idx_end; ++idx)
        for (int64_t in = 0; in < inner_size; in++) {
            std::vector<int64_t> n_top {
                    igen_top[0](msr), igen_top[1](msr), igen_top[2](msr)};
            int i = 2;
            int64_t n_sum = n_top[0] + n_top[1] + n_top[2];
            // Adjust the number of top elements to fit axis_size if needed.
            while (n_sum > axis_size) {
                n_sum -= n_top[i];
                n_top[i] -= std::min(n_top[i], n_sum + n_top[i] - axis_size);
                n_sum += n_top[i];
                i--;
            }
            // If the number of top elements is less than axis_size, set a
            // random index to start the dense filling from.
            std::uniform_int_distribution<> igen_as_idx(0, axis_size - n_sum);
            msr.discard(2);
            int64_t axis_idx_start = igen_as_idx(msr);

            i = 0;
            for (int64_t as = 0; as < axis_size; as++) {
                auto offset = inner_size * (idx * axis_size + as) + in;
                float value = INT_MIN;
                if (as >= axis_idx_start && as < axis_idx_start + n_sum) {
                    value = top_val[i];
                    n_top[i]--;
                    if (n_top[i] == 0) i++;
                }
                mem_fp.set_elem(offset,
                        round_to_nearest_representable(mem_dt.dt(), value));
            }
        }
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

int fill_data_bwd(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, int seed) {
    const auto nelems = mem_fp.nelems();
    const int range = 128;

    // To avoid cancellation errors, it's better to have d_dst and dst of
    // different signs (refer to the reference computations):
    // softmax := d_dst - SUM(d_dst * dst); keep +d_dst and -dst.
    // logsoftmax := d_dst - exp(dst) * SUM(d_dst); keep -d_dst and +dst.
    // The seed decides the sign.
    const float sign = seed % 2 == 0 ? 1.f : -1.f;
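    // The generated values follow a fixed pattern in (-1, 1), so the filling
    // is reproducible regardless of the number of threads.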
    benchdnn_parallel_nd(nelems, [&](int64_t i) {
        const float gen = ((11 * i) + 37 + 19 * seed) % range;
        const float value = sign * gen / range;
        mem_fp.set_elem(i, value);
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

int fill_scales(
        const attr_t &attr, int arg, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
    const auto nelems = mem_fp.nelems();
    if (nelems == 0) return OK;

    assert(mem_dt.nelems() == mem_fp.nelems());

    const auto &scales = attr.scales.get(arg);

    /* Do fixed partitioning to have the same filling for any number of
     * threads. */
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);
    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // Note: we use a different seed for each chunk to avoid
        // repeating patterns. We could use discard(idx_start) too, but
        // its complexity is O(idx_start). We also add 1 to avoid
        // seeding with 0.
        std::minstd_rand int_seed(idx_start + 1);
        int_seed.discard(1);

        std::uniform_int_distribution<> gen(-5, 5);
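        // The generated scales are exact powers of two in [2^-5, 2^5], so
        // applying them introduces no extra rounding error in f32.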

        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            int pow2 = gen(int_seed);
            int pow2_shift = 1 << std::abs(pow2);
            const float gen_val = pow2 < 0 ? (1.f / pow2_shift) : pow2_shift;
            const float fixed_val = scales.scale;
            const float val = nelems == 1 ? fixed_val : gen_val;
            mem_fp.set_elem(idx, val);
        }
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type({prb->sdt, prb->ddt}, prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res);
}

void skip_invalid_prb(const prb_t *prb, res_t *res) {
    // See `skip_invalid_inplace` for details.
    if (prb->inplace) {
        skip_invalid_inplace(res, prb->sdt, prb->ddt, prb->stag, prb->dtag);
        if (res->state == SKIPPED) return;
    }
}

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    const auto trh_dt = (prb->dir & FLAG_FWD) ? prb->ddt : prb->sdt;
    const float trh_coeff_log = prb->alg == LOGSOFTMAX ? 5 : 1;
    const float trh_coeff_f32 = trh_dt == dnnl_f32 ? 10.f : 1.f;
    const float trh_coeff_bwd = (prb->dir & FLAG_FWD) ? 1.f : 4.f;
    const float trh = trh_coeff_log * trh_coeff_bwd * trh_coeff_f32
            * epsilon_dt(trh_dt);
    cmp.set_threshold(trh);

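    // Estimate the expected share of zero output elements: non-top inputs
    // are filled with a huge negative value, so their softmax outputs
    // underflow to zero; for int8 destinations, quantization additionally
    // rounds all but the largest outputs to zero.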
    const int64_t axis_size = prb->dims[prb->axis];
    const int64_t n_zeros = (prb->ddt == dnnl_s8 || prb->ddt == dnnl_u8)
            ? (axis_size - 1)
            : MAX2(0, axis_size - 8);
    float zero_trust_percent = 100.f * n_zeros / axis_size;
    // Note:
    // * Logsoftmax over an axis of size `1` does not make any sense.
    // * Logsoftmax for u8 dst does not make any sense either.
    if (prb->alg == LOGSOFTMAX && (axis_size == 1 || prb->ddt == dnnl_u8))
        zero_trust_percent = 100.f;
    if (prb->dir & FLAG_BWD) zero_trust_percent = 30.f;
    cmp.set_zero_trust_percent(zero_trust_percent);

    const auto softmax_add_check
            = [&](const compare::compare_t::driver_check_func_args_t &args) {
                  // SSE4.1 and OpenCL rdiff tolerance is too high for
                  // certain scenarios.
                  return args.diff < epsilon_dt(args.dt);
              };
    cmp.set_driver_check_function(softmax_add_check);
}

int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

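    // Query memory descriptors from the created primitive descriptor so that
    // test memories match the formats the implementation selected.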
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);
    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_dt, placeholder_dst_dt;
    dnn_mem_t &dst_dt = prb->inplace && (prb->dir & FLAG_FWD)
            ? src_dt
            : placeholder_dst_dt;
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    dnn_mem_t d_dst_dt, placeholder_d_src_dt;
    dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt;

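    // Scale memories hold a single common value each (note the {1} dims
    // below).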
    const dnnl_dims_t scale_dims = {1};
    auto scales_md = dnn_mem_t::init_md(1, scale_dims, dnnl_f32, tag::abx);
    dnn_mem_t src_scales_dt(scales_md, test_engine);
    dnn_mem_t dst_scales_dt(scales_md, test_engine);

    args_t args, ref_args;

    if (prb->dir & FLAG_FWD) {
        const auto &src_md = query_md(const_pd, DNNL_ARG_SRC);
        const auto &dst_md = query_md(const_pd, DNNL_ARG_DST);

        src_dt = dnn_mem_t(src_md, test_engine);
        if (!prb->inplace) {
            placeholder_dst_dt = dnn_mem_t(dst_md, test_engine);
        }

        dnn_mem_t src_fp(src_md, dnnl_f32, tag::abx, ref_engine);
        dnn_mem_t &dst_fp = src_fp; // in-place reference
        SAFE(fill_data_fwd(prb, src_dt, src_fp), WARN);

        dnn_mem_t src_scales_fp(scales_md, ref_engine);
        dnn_mem_t dst_scales_fp(scales_md, ref_engine);
        fill_scales(prb->attr, DNNL_ARG_SRC, src_scales_dt, src_scales_fp);
        fill_scales(prb->attr, DNNL_ARG_DST, dst_scales_dt, dst_scales_fp);

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_fp);

            check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
        }
    } else {
        const auto &dst_md = query_md(const_pd, DNNL_ARG_DST);
        const auto &d_dst_md = query_md(const_pd, DNNL_ARG_DIFF_DST);
        const auto &d_src_md = query_md(const_pd, DNNL_ARG_DIFF_SRC);

        placeholder_dst_dt = dnn_mem_t(dst_md, test_engine);
        d_dst_dt = dnn_mem_t(d_dst_md, test_engine);
        if (!prb->inplace) {
            placeholder_d_src_dt = dnn_mem_t(d_src_md, test_engine);
        }

        dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, ref_engine);
        dnn_mem_t d_dst_fp(d_dst_md, dnnl_f32, tag::abx, ref_engine);
        dnn_mem_t &d_src_fp = d_dst_fp; // in-place reference

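        // Give dst and d_dst opposite signs (see the comment in
        // fill_data_bwd) to avoid cancellation in the reference computation.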
        const bool neg_sign = prb->alg == SOFTMAX;
        SAFE(fill_data_bwd(prb, dst_dt, dst_fp, neg_sign), WARN);
        SAFE(fill_data_bwd(prb, d_dst_dt, d_dst_fp, !neg_sign), WARN);

        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp);

            check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res);
        }
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace softmax