1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "softmax/softmax.hpp" |
20 | |
21 | namespace softmax { |
22 | |
23 | void compute_ref_fwd(const prb_t *prb, const args_t &args) { |
24 | const dnn_mem_t &src = args.find(DNNL_ARG_SRC); |
25 | const dnn_mem_t &dst = args.find(DNNL_ARG_DST); |
26 | const dnn_mem_t &src_scale = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); |
27 | const dnn_mem_t &dst_scale = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); |
28 | |
29 | float *dst_ptr = (float *)dst; |
30 | |
31 | const auto alg = prb->alg; |
32 | int64_t outer_size {0}, inner_size {0}, axis_size {0}; |
33 | get_sizes(prb, outer_size, inner_size, axis_size); |
34 | |
35 | assert(src_scale.nelems() == 1 && dst_scale.nelems() == 1); |
36 | const float output_scale = src_scale.get_elem(0) / dst_scale.get_elem(0); |
37 | |
38 | benchdnn_parallel_nd(outer_size, inner_size, [&](int64_t ou, int64_t in) { |
39 | float space_denom = 0.; |
40 | float space_max = -FLT_MAX; |
41 | int64_t ou_in_offset = ou * axis_size * inner_size + in; |
42 | |
43 | for (int64_t as = 0; as < axis_size; ++as) { |
44 | int64_t idx = ou_in_offset + as * inner_size; |
45 | space_max = MAX2(space_max, src.get_elem(idx)); |
46 | } |
47 | |
48 | for (int64_t as = 0; as < axis_size; ++as) { |
49 | int64_t idx = ou_in_offset + as * inner_size; |
50 | float s = src.get_elem(idx); |
51 | if (alg == SOFTMAX) { |
52 | float D = dst_ptr[idx] = expf(s - space_max); |
53 | space_denom += D; |
54 | } else if (alg == LOGSOFTMAX) { |
55 | float D = dst_ptr[idx] = s - space_max; |
56 | space_denom += expf(D); |
57 | } |
58 | } |
59 | |
60 | if (alg == SOFTMAX) { |
61 | space_denom = space_denom ? (1.f / space_denom) : 1.f; |
62 | } else if (alg == LOGSOFTMAX) { |
63 | space_denom = logf(space_denom); |
64 | } |
65 | |
66 | for (int64_t as = 0; as < axis_size; ++as) { |
67 | int64_t idx = ou_in_offset + as * inner_size; |
68 | if (alg == SOFTMAX) { |
69 | dst_ptr[idx] *= space_denom; |
70 | } else if (alg == LOGSOFTMAX) { |
71 | dst_ptr[idx] -= space_denom; |
72 | } |
73 | |
74 | dst_ptr[idx] *= output_scale; |
75 | } |
76 | }); |
77 | } |
78 | |
79 | void compute_ref_bwd(const prb_t *prb, const args_t &args) { |
80 | const dnn_mem_t &dst = args.find(DNNL_ARG_DST); |
81 | const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST); |
82 | const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC); |
83 | |
84 | float *d_src_ptr = (float *)d_src; |
85 | |
86 | const auto alg = prb->alg; |
87 | int64_t outer_size {0}, inner_size {0}, axis_size {0}; |
88 | get_sizes(prb, outer_size, inner_size, axis_size); |
89 | |
90 | benchdnn_parallel_nd(outer_size, inner_size, [&](int64_t ou, int64_t in) { |
91 | float part_deriv_sum = 0.; |
92 | int64_t ou_in_offset = ou * axis_size * inner_size + in; |
93 | |
94 | for (int64_t as = 0; as < axis_size; ++as) { |
95 | int64_t idx = ou_in_offset + as * inner_size; |
96 | float d = dst.get_elem(idx); |
97 | float dd = d_dst.get_elem(idx); |
98 | if (alg == SOFTMAX) { |
99 | part_deriv_sum += dd * d; |
100 | } else if (alg == LOGSOFTMAX) { |
101 | part_deriv_sum += dd; |
102 | } |
103 | } |
104 | |
105 | for (int64_t as = 0; as < axis_size; ++as) { |
106 | int64_t idx = ou_in_offset + as * inner_size; |
107 | float d = dst.get_elem(idx); |
108 | float dd = d_dst.get_elem(idx); |
109 | if (alg == SOFTMAX) { |
110 | d_src_ptr[idx] = d * (dd - part_deriv_sum); |
111 | } else if (alg == LOGSOFTMAX) { |
112 | d_src_ptr[idx] = dd - expf(d) * part_deriv_sum; |
113 | } |
114 | } |
115 | }); |
116 | } |
117 | |
118 | void compute_ref( |
119 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
120 | if (prb->dir & FLAG_FWD) |
121 | compute_ref_fwd(prb, args); |
122 | else |
123 | compute_ref_bwd(prb, args); |
124 | } |
125 | |
126 | } // namespace softmax |
127 | |