/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <float.h>
#include <math.h>

#include <cstdlib> // std::div, used by the zero-padding path
#include <cstring> // std::memset, used by the zero-padding path

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"

#include "cpu/cpu_primitive.hpp"

#include "cpu/ref_io_helper.hpp"
#include "cpu/ref_softmax.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

static bool is_padding(const memory_desc_wrapper &md) {
    for (int i = 0; i < md.ndims(); i++)
        if (md.dims()[i] != md.padded_dims()[i]) return true;
    return false;
}

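// Computes softmax along the last physical dimension (unit stride) in three
// passes over the axis: (1) find the max, (2) subtract it and exponentiate,
// accumulating the denominator, (3) normalize. Subtracting the max before
// expf() keeps the exponentials in (0, 1] and avoids overflow.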
status_t ref_softmax_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
    using namespace memory_tracking::names;

    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

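    // For int8 destinations the intermediate exp values cannot be stored in
    // dst without losing precision, so each thread gets an f32 strip of
    // scratchpad holding one axis worth of intermediate results.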
    float *scratchpad_int8 = ctx.get_scratchpad_grantor().template get<float>(
            key_softmax_interim_store);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());

    const auto interim_dt
            = pd()->need_int8_scratchpad() ? data_type::f32 : dst_d.data_type();

    const dim_t ou_stride = pd()->outer_stride();
    const auto is_inplace = (src == dst);
    const auto has_padding = is_padding(dst_d);
    const auto zero_padding = has_padding && !is_inplace;
    const auto axis = pd()->axis();
    const auto axis_size = pd()->axis_size(true);
    const auto axis_blk_size = src_d.padded_dims()[axis] - src_d.dims()[axis];
    const auto src_dt_size = types::data_type_size(pd()->src_md()->data_type);
    const auto dst_dt_size = types::data_type_size(pd()->dst_md()->data_type);

    const int nthr = pd()->nthr_;

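    // Parallelize over the outer dimension; each iteration handles one full
    // softmax axis.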
    parallel_nd_ext(nthr, outer_size_, [&](int ithr, int, dim_t ou) {
        const void *src_data = reinterpret_cast<const char *>(src)
                + ou * ou_stride * src_dt_size;
        void *dst_data
                = reinterpret_cast<char *>(dst) + ou * ou_stride * dst_dt_size;
        void *interim_ptr = pd()->need_int8_scratchpad()
                ? (scratchpad_int8 + ithr * axis_size)
                : dst_data;

        float space_max = -FLT_MAX;
        float space_denom = 0;
        constexpr int unroll_factor = 32;

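        // Pass 1: max reduction along the axis.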
// The Intel(R) C++ Compiler already generates the faster maxps + shuffle
// pattern for the max search, so the manual unrolling below is only needed
// for other compilers.
#if !defined(__INTEL_COMPILER)
        // The code below makes the compiler generate the maxps instruction
        // rather than maxss, which is generated for the 'else' code path.
        auto max_wrapper = [](float a, float b) { return nstl::max(a, b); };
        auto min_wrapper = [](int a, int b) { return nstl::min(a, b); };

        if (channels_ < unroll_factor) {
            float max_val = -FLT_MAX;
            for (int i = 0; i < channels_; i++) {
                max_val = max_wrapper(max_val,
                        io::load_float_value(src_d.data_type(), src_data, i));
            }
            space_max = max_val;
        } else {
            float max_values[unroll_factor];

            for (int i = 0; i < unroll_factor; i++) {
                max_values[i]
                        = io::load_float_value(src_d.data_type(), src_data, i);
            }
            for (int i = unroll_factor; i < channels_; i += unroll_factor) {
                int offset = min_wrapper(i, channels_ - unroll_factor);
                for (int j = 0; j < unroll_factor; j++) {
                    max_values[j] = max_wrapper(max_values[j],
                            io::load_float_value(
                                    src_d.data_type(), src_data, offset + j));
                }
            }
            float max_val = -FLT_MAX;
            for (int i = 0; i < unroll_factor; i++) {
                max_val = max_wrapper(max_val, max_values[i]);
            }
            space_max = max_val;
        }
#else
        for (int c = 0; c < channels_; ++c)
            space_max = nstl::max(space_max,
                    io::load_float_value(src_d.data_type(), src_data, c));
#endif

        // Pass 2: subtract the max, exponentiate, and accumulate the
        // denominator; logsoftmax stores the shifted value and sums its exp.
        int tail = channels_ % unroll_factor;
        for (int i = 0; i < channels_ - tail; i += unroll_factor) {
            PRAGMA_OMP_SIMD(reduction(+ : space_denom))
            for (int j = 0; j < unroll_factor; j++) {
                float s = io::load_float_value(
                        src_d.data_type(), src_data, i + j);
                float d = s - space_max;
                if (pd()->is_softmax()) {
                    d = expf(d);
                    space_denom += d;
                } else if (pd()->is_logsoftmax()) {
                    space_denom += expf(d);
                }

                io::store_float_value(interim_dt, d, interim_ptr, i + j);
            }
        }
        for (int i = channels_ - tail; i < channels_; i++) {
            float s = io::load_float_value(src_d.data_type(), src_data, i);
            float d = s - space_max;
            if (pd()->is_softmax()) {
                d = expf(d);
                space_denom += d;
            } else if (pd()->is_logsoftmax()) {
                space_denom += expf(d);
            }
            io::store_float_value(interim_dt, d, interim_ptr, i);
        }

        // Pass 3: normalize. Softmax multiplies by 1/denom; logsoftmax
        // subtracts log(denom).
        if (pd()->is_softmax()) {
            space_denom = space_denom ? (1.f / space_denom) : 1.f;
        } else if (pd()->is_logsoftmax()) {
            space_denom = logf(space_denom);
        }
        for (int c = 0; c < channels_; ++c) {
            float d = io::load_float_value(interim_dt, interim_ptr, c);
            float val = 0;
            if (pd()->is_softmax()) {
                val = d * space_denom;
            } else if (pd()->is_logsoftmax()) {
                val = d - space_denom;
            }
            val *= src_scales[0] * dst_scales[0];
            io::store_float_value(dst_d.data_type(), val, dst_data, c);
        }
        if (zero_padding) {
            PRAGMA_OMP_SIMD()
            for (int i = 0; i < axis_blk_size; i++)
                io::store_float_value(
                        dst_d.data_type(), 0, dst_data, channels_ + i);
        }
    });
    return status::success;
}

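// Handles arbitrary memory layouts: logical offsets are translated through
// off_l(), and for inner_size_ > 1 the reduction runs with a non-unit stride
// along the softmax axis.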
status_t ref_softmax_fwd_t::execute_forward_generic(
        const exec_ctx_t &ctx) const {
    using namespace memory_tracking::names;

    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    float *scratchpad_int8 = ctx.get_scratchpad_grantor().template get<float>(
            key_softmax_interim_store);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());

    void *interim_ptr = pd()->need_int8_scratchpad() ? scratchpad_int8 : dst;
    const auto interim_dt
            = pd()->need_int8_scratchpad() ? data_type::f32 : dst_d.data_type();

    const auto is_inplace = (src == dst);
    const auto has_padding = is_padding(dst_d);
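    // When dst has padded dims and the op runs out-of-place, pre-zero the
    // destination so the padded area ends up zeroed. A buffer that is dense
    // including padding is cleared directly in parallel 4K-byte pages; any
    // other layout falls back to zero_pad_output().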
    if (has_padding && !is_inplace) {
        if (dst_d.is_dense(true)) {
            const auto res = std::div(static_cast<int>(dst_d.size()), PAGE_4K);
            if (!res.quot)
                std::memset(dst, 0, res.rem);
            else
                parallel_nd(res.quot, [&](dim_t i) {
                    const auto tail = (i + 1 == res.quot) ? res.rem : 0;
                    const auto ptr_dst = reinterpret_cast<unsigned char *>(dst)
                            + i * PAGE_4K;
                    std::memset(ptr_dst, 0, PAGE_4K + tail);
                });
        } else
            // needed for submemory correctness
            ctx.zero_pad_output(DNNL_ARG_DST);
    }

    const auto axis_size = pd()->axis_size(true);
    const int nthr = pd()->nthr_;

    parallel_nd_ext(nthr, outer_size_, [&](int ithr, int, dim_t ou) {
        const dim_t thr_shift = ithr * axis_size;

        float space_max_val = 0, space_denom_val = 0;
        float *space_max = &space_max_val, *space_denom = &space_denom_val;
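        // With inner_size_ > 1 a max and a denominator are tracked per inner
        // element, so the on-stack scalars are replaced with a slice of the
        // reduction scratchpad (2 * inner_size_ floats per outer index).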
        if (inner_size_ > 1) {
            space_max = ctx.get_scratchpad_grantor().template get<float>(
                                key_softmax_reduction)
                    + ou * 2 * inner_size_;
            space_denom = space_max + inner_size_;
        }

        utils::array_set(space_max, -FLT_MAX, inner_size_);
        utils::array_set(space_denom, 0, inner_size_);

        for (int in = 0; in < inner_size_; in++) {
            dim_t ou_in_offset = ou * channels_ * inner_size_ + in;

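            // Same three passes as the dense path (max, exp + sum, normalize),
            // but with a stride of inner_size_ between axis elements.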
            for (int c = 0; c < channels_; c++) {
                size_t off = src_d.off_l(ou_in_offset + c * inner_size_);
                float s = io::load_float_value(src_d.data_type(), src, off);
                space_max[in] = nstl::max(space_max[in], s);
            }

            for (int c = 0; c < channels_; c++) {
                size_t src_off = src_d.off_l(ou_in_offset + c * inner_size_);
                float s = io::load_float_value(src_d.data_type(), src, src_off);
                float d = s - space_max[in];
                if (pd()->is_softmax()) {
                    d = expf(d);
                    space_denom[in] += d;
                } else if (pd()->is_logsoftmax()) {
                    space_denom[in] += expf(d);
                }
                size_t dst_off = dst_d.off_l(ou_in_offset + c * inner_size_);
                size_t interim_off = pd()->need_int8_scratchpad()
                        ? thr_shift + c
                        : dst_off;
                io::store_float_value(interim_dt, d, interim_ptr, interim_off);
            }

            if (pd()->is_logsoftmax()) {
                space_denom[in] = logf(space_denom[in]);
            }

            for (int c = 0; c < channels_; c++) {
                size_t dst_off = dst_d.off_l(ou_in_offset + c * inner_size_);
                size_t interim_off = pd()->need_int8_scratchpad()
                        ? thr_shift + c
                        : dst_off;
                float d = io::load_float_value(
                        interim_dt, interim_ptr, interim_off);
                float sd = space_denom[in];
                if (pd()->is_softmax()) {
                    d /= sd;
                } else if (pd()->is_logsoftmax()) {
                    d -= sd;
                }
                d *= src_scales[0] * dst_scales[0];
                io::store_float_value(dst_d.data_type(), d, dst, dst_off);
            }
        }
    });
    return status::success;
}

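// Backward formulas implemented below (per axis element i):
//   softmax:    diff_src_i = dst_i * (diff_dst_i - sum_j(diff_dst_j * dst_j))
//   logsoftmax: diff_src_i = diff_dst_i - exp(dst_i) * sum_j(diff_dst_j)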
// softmax along last physical dimension
status_t ref_softmax_bwd_t::execute_backward_dense(
        const exec_ctx_t &ctx) const {
    auto dst = CTX_IN_MEM(const void *, DNNL_ARG_DST);
    auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());

    const auto ou_stride = pd()->outer_stride();

    parallel_nd(outer_size_, [&](dim_t ou) {
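        // sbr holds the axis-wide reduction: sum(diff_dst * dst) for softmax,
        // sum(diff_dst) for logsoftmax.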
        float sbr = 0;
        size_t off = ou * ou_stride;
        if (pd()->is_softmax()) {
            for (size_t loff = off; loff < off + channels_; ++loff) {
                float d = io::load_float_value(dst_d.data_type(), dst, loff);
                float dd = io::load_float_value(
                        diff_dst_d.data_type(), diff_dst, loff);
                sbr += dd * d;
            }
            for (size_t loff = off; loff < off + channels_; ++loff) {
                float d = io::load_float_value(dst_d.data_type(), dst, loff);
                float dd = io::load_float_value(
                        diff_dst_d.data_type(), diff_dst, loff);
                float val = d * (dd - sbr);
                io::store_float_value(
                        diff_src_d.data_type(), val, diff_src, loff);
            }
        } else if (pd()->is_logsoftmax()) {
            for (size_t loff = off; loff < off + channels_; ++loff) {
                float dd = io::load_float_value(
                        diff_dst_d.data_type(), diff_dst, loff);
                sbr += dd;
            }
            for (size_t loff = off; loff < off + channels_; ++loff) {
                float d = io::load_float_value(dst_d.data_type(), dst, loff);
                float dd = io::load_float_value(
                        diff_dst_d.data_type(), diff_dst, loff);
                float val = dd - expf(d) * sbr;
                io::store_float_value(
                        diff_src_d.data_type(), val, diff_src, loff);
            }
        }
    });
    return status::success;
}

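// Backward counterpart of the generic forward path: same formulas as the
// dense backward, with offsets resolved through off_l().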
status_t ref_softmax_bwd_t::execute_backward_generic(
        const exec_ctx_t &ctx) const {
    auto dst = CTX_IN_MEM(const void *, DNNL_ARG_DST);
    auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());

    const auto is_inplace = (diff_dst == diff_src);
    const auto has_padding = is_padding(diff_dst_d);
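    // Pre-zero padded diff_src for out-of-place runs, mirroring the padding
    // handling in the forward path.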
    if (has_padding && !is_inplace) {
        if (diff_dst_d.is_dense(true)) {
            const auto res
                    = std::div(static_cast<int>(diff_dst_d.size()), PAGE_4K);
            if (!res.quot)
                std::memset(diff_src, 0, res.rem);
            else
                parallel_nd(res.quot, [&](dim_t i) {
                    const auto tail = (i + 1 == res.quot) ? res.rem : 0;
                    const auto ptr_dst
                            = reinterpret_cast<unsigned char *>(diff_src)
                            + i * PAGE_4K;
                    std::memset(ptr_dst, 0, PAGE_4K + tail);
                });
        } else
            // needed for submemory correctness
            ctx.zero_pad_output(DNNL_ARG_DIFF_SRC);
    }

    parallel_nd(outer_size_, inner_size_, [&](dim_t ou, dim_t in) {
        dim_t ou_in_offset = ou * channels_ * inner_size_ + in;
        float sbr = 0;
        for (int c = 0; c < channels_; ++c) {
            auto diff_dst_off
                    = diff_dst_d.off_l(ou_in_offset + c * inner_size_);
            float dd = io::load_float_value(
                    diff_dst_d.data_type(), diff_dst, diff_dst_off);
            if (pd()->is_softmax()) {
                auto dst_off = dst_d.off_l(ou_in_offset + c * inner_size_);
                float d = io::load_float_value(dst_d.data_type(), dst, dst_off);
                sbr += dd * d;
            } else if (pd()->is_logsoftmax()) {
                sbr += dd;
            }
        }

        for (int c = 0; c < channels_; ++c) {
            auto diff_dst_off
                    = diff_dst_d.off_l(ou_in_offset + c * inner_size_);
            auto dst_off = dst_d.off_l(ou_in_offset + c * inner_size_);
            float d = io::load_float_value(dst_d.data_type(), dst, dst_off);
            float dd = io::load_float_value(
                    diff_dst_d.data_type(), diff_dst, diff_dst_off);
            float val = 0;
            if (pd()->is_softmax()) {
                val = d * (dd - sbr);
            } else if (pd()->is_logsoftmax()) {
                val = dd - expf(d) * sbr;
            }
            auto diff_src_off
                    = diff_src_d.off_l(ou_in_offset + c * inner_size_);
            io::store_float_value(
                    diff_src_d.data_type(), val, diff_src, diff_src_off);
        }
    });
    return status::success;
}

} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s