/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cassert> // assert
#include <cfloat> // FLT_MAX
#include <cmath> // expf, logf

#include "utils/parallel.hpp"

#include "softmax/softmax.hpp"

namespace softmax {

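// Reference forward implementation: computes softmax or logsoftmax along the
// problem axis using the numerically stable (max-subtracted) formulation.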
void compute_ref_fwd(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &dst = args.find(DNNL_ARG_DST);
    const dnn_mem_t &src_scale = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
    const dnn_mem_t &dst_scale = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);

    float *dst_ptr = (float *)dst;

    const auto alg = prb->alg;
    int64_t outer_size {0}, inner_size {0}, axis_size {0};
    get_sizes(prb, outer_size, inner_size, axis_size);

    // Scales are single common values; fold them into one output scale.
    assert(src_scale.nelems() == 1 && dst_scale.nelems() == 1);
    const float output_scale = src_scale.get_elem(0) / dst_scale.get_elem(0);

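    // Each (outer, inner) pair indexes an independent slice along the axis.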
    benchdnn_parallel_nd(outer_size, inner_size, [&](int64_t ou, int64_t in) {
        float space_denom = 0.;
        float space_max = -FLT_MAX;
        int64_t ou_in_offset = ou * axis_size * inner_size + in;

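        // Pass 1: find the maximum along the axis for numerical stability.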
        for (int64_t as = 0; as < axis_size; ++as) {
            int64_t idx = ou_in_offset + as * inner_size;
            space_max = MAX2(space_max, src.get_elem(idx));
        }

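        // Pass 2: store exp(s - max) (softmax) or s - max (logsoftmax) and
        // accumulate the denominator.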
        for (int64_t as = 0; as < axis_size; ++as) {
            int64_t idx = ou_in_offset + as * inner_size;
            float s = src.get_elem(idx);
            if (alg == SOFTMAX) {
                float D = dst_ptr[idx] = expf(s - space_max);
                space_denom += D;
            } else if (alg == LOGSOFTMAX) {
                float D = dst_ptr[idx] = s - space_max;
                space_denom += expf(D);
            }
        }

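        // Convert the denominator into a multiplier (softmax) or into
        // log(sum(exp)) to be subtracted (logsoftmax).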
        if (alg == SOFTMAX) {
            space_denom = space_denom ? (1.f / space_denom) : 1.f;
        } else if (alg == LOGSOFTMAX) {
            space_denom = logf(space_denom);
        }

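        // Pass 3: normalize each element and apply the combined scale.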
        for (int64_t as = 0; as < axis_size; ++as) {
            int64_t idx = ou_in_offset + as * inner_size;
            if (alg == SOFTMAX) {
                dst_ptr[idx] *= space_denom;
            } else if (alg == LOGSOFTMAX) {
                dst_ptr[idx] -= space_denom;
            }

            dst_ptr[idx] *= output_scale;
        }
    });
}

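// Reference backward implementation: given dst and diff_dst, computes
// diff_src for softmax or logsoftmax.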
void compute_ref_bwd(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &dst = args.find(DNNL_ARG_DST);
    const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST);
    const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC);

    float *d_src_ptr = (float *)d_src;

    const auto alg = prb->alg;
    int64_t outer_size {0}, inner_size {0}, axis_size {0};
    get_sizes(prb, outer_size, inner_size, axis_size);

    benchdnn_parallel_nd(outer_size, inner_size, [&](int64_t ou, int64_t in) {
        float part_deriv_sum = 0.;
        int64_t ou_in_offset = ou * axis_size * inner_size + in;

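        // Accumulate the reduction term over the axis:
        // sum(d_dst * dst) for softmax, sum(d_dst) for logsoftmax.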
        for (int64_t as = 0; as < axis_size; ++as) {
            int64_t idx = ou_in_offset + as * inner_size;
            float d = dst.get_elem(idx);
            float dd = d_dst.get_elem(idx);
            if (alg == SOFTMAX) {
                part_deriv_sum += dd * d;
            } else if (alg == LOGSOFTMAX) {
                part_deriv_sum += dd;
            }
        }

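        // Apply the Jacobian: d_src = dst * (d_dst - sum) for softmax;
        // d_src = d_dst - exp(dst) * sum for logsoftmax.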
        for (int64_t as = 0; as < axis_size; ++as) {
            int64_t idx = ou_in_offset + as * inner_size;
            float d = dst.get_elem(idx);
            float dd = d_dst.get_elem(idx);
            if (alg == SOFTMAX) {
                d_src_ptr[idx] = d * (dd - part_deriv_sum);
            } else if (alg == LOGSOFTMAX) {
                d_src_ptr[idx] = dd - expf(d) * part_deriv_sum;
            }
        }
    });
}

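// Entry point: dispatch to the forward or backward reference based on the
// problem direction; prim_ref is accepted for interface compatibility but is
// not used in this path.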
void compute_ref(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (prb->dir & FLAG_FWD)
        compute_ref_fwd(prb, args);
    else
        compute_ref_bwd(prb, args);
}

} // namespace softmax