1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <float.h> |
19 | #include <math.h> |
20 | |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | |
25 | #include "cpu/cpu_primitive.hpp" |
26 | |
27 | #include "cpu/ref_io_helper.hpp" |
28 | #include "cpu/ref_softmax.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | |
34 | static bool is_padding(const memory_desc_wrapper &md) { |
35 | for (int i = 0; i < md.ndims(); i++) |
36 | if (md.dims()[i] != md.padded_dims()[i]) return true; |
37 | return false; |
38 | } |
39 | |
// Forward softmax/logsoftmax when the softmax axis is the innermost, dense
// (unit-stride) dimension. One thread processes one "outer" row at a time:
//   1) find the row maximum (subtracting it keeps expf() from overflowing),
//   2) store exp(x - max) (softmax) or x - max (logsoftmax) into the interim
//      buffer while accumulating the sum of exponentials,
//   3) normalize, apply src/dst scales, and store in dst's native data type.
status_t ref_softmax_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
    using namespace memory_tracking::names;

    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Per-thread f32 scratch area; only dereferenced when
    // need_int8_scratchpad() is true (see interim_ptr below).
    float *scratchpad_int8 = ctx.get_scratchpad_grantor().template get<float>(
            key_softmax_interim_store);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());

    // Intermediate (pass-2) values are kept in f32 in the scratchpad when
    // required; otherwise they are written straight into dst in its own type.
    const auto interim_dt
            = pd()->need_int8_scratchpad() ? data_type::f32 : dst_d.data_type();

    const dim_t ou_stride = pd()->outer_stride();
    const auto is_inplace = (src == dst);
    const auto has_padding = is_padding(dst_d);
    // The padded tail of the axis is zeroed only for out-of-place execution;
    // presumably for in-place runs the padded area is already valid — the
    // code never touches it in that case.
    const auto zero_padding = has_padding && !is_inplace;
    const auto axis = pd()->axis();
    const auto axis_size = pd()->axis_size(true); // padded axis length
    // Number of padding elements past the logical end of the axis.
    const auto axis_blk_size = src_d.padded_dims()[axis] - src_d.dims()[axis];
    const auto src_dt_size = types::data_type_size(pd()->src_md()->data_type);
    const auto dst_dt_size = types::data_type_size(pd()->dst_md()->data_type);

    const int nthr = pd()->nthr_;

    parallel_nd_ext(nthr, outer_size_, [&](int ithr, int, dim_t ou) {
        // Byte addresses of the current row in src and dst.
        const void *src_data = reinterpret_cast<const char *>(src)
                + ou * ou_stride * src_dt_size;
        void *dst_data
                = reinterpret_cast<char *>(dst) + ou * ou_stride * dst_dt_size;
        // Each thread owns a disjoint axis_size-sized slice of the scratchpad.
        void *interim_ptr = pd()->need_int8_scratchpad()
                ? (scratchpad_int8 + ithr * axis_size)
                : dst_data;

        float space_max = -FLT_MAX;
        float space_denom = 0;
        constexpr int unroll_factor = 32;

// Intel(R) C++ Compiler generates the maxps + shuffle pattern
// for the max search which works faster
#if !defined(__INTEL_COMPILER)
        // The code below makes the compiler generate maxps instruction.
        // rather than maxss, which is generated for the 'else' code path
        auto max_wrapper = [](float a, float b) { return nstl::max(a, b); };
        auto min_wrapper = [](int a, int b) { return nstl::min(a, b); };

        if (channels_ < unroll_factor) {
            // Short axis: plain scalar max search.
            float max_val = -FLT_MAX;
            for (int i = 0; i < channels_; i++) {
                max_val = max_wrapper(max_val,
                        io::load_float_value(src_d.data_type(), src_data, i));
            }
            space_max = max_val;
        } else {
            // Keep unroll_factor independent partial maxima so the compiler
            // can vectorize; reduce them to a single scalar at the end.
            float max_values[unroll_factor];

            for (int i = 0; i < unroll_factor; i++) {
                max_values[i]
                        = io::load_float_value(src_d.data_type(), src_data, i);
            }
            for (int i = unroll_factor; i < channels_; i += unroll_factor) {
                // Clamp the last block so it re-reads the tail instead of
                // running past the end (max is idempotent, so overlap is ok).
                int offset = min_wrapper(i, channels_ - unroll_factor);
                for (int j = 0; j < unroll_factor; j++) {
                    max_values[j] = max_wrapper(max_values[j],
                            io::load_float_value(
                                    src_d.data_type(), src_data, offset + j));
                }
            }
            float max_val = -FLT_MAX;
            for (int i = 0; i < unroll_factor; i++) {
                max_val = max_wrapper(max_val, max_values[i]);
            }
            space_max = max_val;
        }
#else
        for (int c = 0; c < channels_; ++c)
            space_max = nstl::max(space_max,
                    io::load_float_value(src_d.data_type(), src_data, c));
#endif

        // sub + exp + sum
        int tail = channels_ % unroll_factor;
        for (int i = 0; i < channels_ - tail; i += unroll_factor) {
            PRAGMA_OMP_SIMD(reduction(+ : space_denom))
            for (int j = 0; j < unroll_factor; j++) {
                float s = io::load_float_value(
                        src_d.data_type(), src_data, i + j);
                float d = s - space_max;
                if (pd()->is_softmax()) {
                    d = expf(d);
                    space_denom += d;
                } else if (pd()->is_logsoftmax()) {
                    // Logsoftmax stores the shifted value, not the exponent.
                    space_denom += expf(d);
                }

                io::store_float_value(interim_dt, d, interim_ptr, i + j);
            }
        }
        // Scalar remainder of the unrolled loop above.
        for (int i = channels_ - tail; i < channels_; i++) {
            float s = io::load_float_value(src_d.data_type(), src_data, i);
            float d = s - space_max;
            if (pd()->is_softmax()) {
                d = expf(d);
                space_denom += d;
            } else if (pd()->is_logsoftmax()) {
                space_denom += expf(d);
            }
            io::store_float_value(interim_dt, d, interim_ptr, i);
        }

        // scal
        if (pd()->is_softmax()) {
            // Turn the sum into a multiplier; guard against a zero sum.
            space_denom = space_denom ? (1.f / space_denom) : 1.f;
        } else if (pd()->is_logsoftmax()) {
            space_denom = logf(space_denom);
        }
        for (int c = 0; c < channels_; ++c) {
            float d = io::load_float_value(interim_dt, interim_ptr, c);
            float val = 0;
            if (pd()->is_softmax()) {
                val = d * space_denom;
            } else if (pd()->is_logsoftmax()) {
                val = d - space_denom;
            }
            val *= src_scales[0] * dst_scales[0];
            io::store_float_value(dst_d.data_type(), val, dst_data, c);
        }
        // Zero out this row's padded tail (out-of-place runs only).
        if (zero_padding) {
            PRAGMA_OMP_SIMD()
            for (int i = 0; i < axis_blk_size; i++)
                io::store_float_value(
                        dst_d.data_type(), 0, dst_data, channels_ + i);
        }
    });
    return status::success;
}
181 | |
182 | status_t ref_softmax_fwd_t::execute_forward_generic( |
183 | const exec_ctx_t &ctx) const { |
184 | using namespace memory_tracking::names; |
185 | |
186 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
187 | auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST); |
188 | |
189 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
190 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
191 | |
192 | float *scratchpad_int8 = ctx.get_scratchpad_grantor().template get<float>( |
193 | key_softmax_interim_store); |
194 | |
195 | const memory_desc_wrapper src_d(pd()->src_md()); |
196 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
197 | |
198 | void *interim_ptr = pd()->need_int8_scratchpad() ? scratchpad_int8 : dst; |
199 | const auto interim_dt |
200 | = pd()->need_int8_scratchpad() ? data_type::f32 : dst_d.data_type(); |
201 | |
202 | const auto is_inplace = (src == dst); |
203 | const auto has_padding = is_padding(dst_d); |
204 | if (has_padding && !is_inplace) { |
205 | if (dst_d.is_dense(true)) { |
206 | const auto res = std::div(static_cast<int>(dst_d.size()), PAGE_4K); |
207 | if (!res.quot) |
208 | std::memset(dst, 0, res.rem); |
209 | else |
210 | parallel_nd(res.quot, [&](dim_t i) { |
211 | const auto tail = (i + 1 == res.quot) ? res.rem : 0; |
212 | const auto ptr_dst = reinterpret_cast<unsigned char *>(dst) |
213 | + i * PAGE_4K; |
214 | std::memset(ptr_dst, 0, PAGE_4K + tail); |
215 | }); |
216 | } else |
217 | // needed for submemory correctness |
218 | ctx.zero_pad_output(DNNL_ARG_DST); |
219 | } |
220 | |
221 | const auto axis_size = pd()->axis_size(true); |
222 | const int nthr = pd()->nthr_; |
223 | |
224 | parallel_nd_ext(nthr, outer_size_, [&](int ithr, int, dim_t ou) { |
225 | const dim_t thr_shift = ithr * axis_size; |
226 | |
227 | float space_max_val = 0, space_denom_val = 0; |
228 | float *space_max = &space_max_val, *space_denom = &space_denom_val; |
229 | if (inner_size_ > 1) { |
230 | space_max = ctx.get_scratchpad_grantor().template get<float>( |
231 | key_softmax_reduction) |
232 | + ou * 2 * inner_size_; |
233 | space_denom = space_max + inner_size_; |
234 | } |
235 | |
236 | utils::array_set(space_max, -FLT_MAX, inner_size_); |
237 | utils::array_set(space_denom, 0, inner_size_); |
238 | |
239 | for (int in = 0; in < inner_size_; in++) { |
240 | dim_t ou_in_offset = ou * channels_ * inner_size_ + in; |
241 | |
242 | for (int c = 0; c < channels_; c++) { |
243 | size_t off = src_d.off_l(ou_in_offset + c * inner_size_); |
244 | float s = io::load_float_value(src_d.data_type(), src, off); |
245 | space_max[in] = nstl::max(space_max[in], s); |
246 | } |
247 | |
248 | for (int c = 0; c < channels_; c++) { |
249 | size_t src_off = src_d.off_l(ou_in_offset + c * inner_size_); |
250 | float s = io::load_float_value(src_d.data_type(), src, src_off); |
251 | float d = s - space_max[in]; |
252 | if (pd()->is_softmax()) { |
253 | d = expf(d); |
254 | space_denom[in] += d; |
255 | } else if (pd()->is_logsoftmax()) { |
256 | space_denom[in] += expf(d); |
257 | } |
258 | size_t dst_off = dst_d.off_l(ou_in_offset + c * inner_size_); |
259 | size_t interim_off = pd()->need_int8_scratchpad() |
260 | ? thr_shift + c |
261 | : dst_off; |
262 | io::store_float_value(interim_dt, d, interim_ptr, interim_off); |
263 | } |
264 | |
265 | if (pd()->is_logsoftmax()) { |
266 | space_denom[in] = logf(space_denom[in]); |
267 | } |
268 | |
269 | for (int c = 0; c < channels_; c++) { |
270 | size_t dst_off = dst_d.off_l(ou_in_offset + c * inner_size_); |
271 | size_t interim_off = pd()->need_int8_scratchpad() |
272 | ? thr_shift + c |
273 | : dst_off; |
274 | float d = io::load_float_value( |
275 | interim_dt, interim_ptr, interim_off); |
276 | float sd = space_denom[in]; |
277 | if (pd()->is_softmax()) { |
278 | d /= sd; |
279 | } else if (pd()->is_logsoftmax()) { |
280 | d -= sd; |
281 | } |
282 | d *= src_scales[0] * dst_scales[0]; |
283 | io::store_float_value(dst_d.data_type(), d, dst, dst_off); |
284 | } |
285 | } |
286 | }); |
287 | return status::success; |
288 | } |
289 | |
290 | // softmax along last physical dimension |
291 | status_t ref_softmax_bwd_t::execute_backward_dense( |
292 | const exec_ctx_t &ctx) const { |
293 | auto dst = CTX_IN_MEM(const void *, DNNL_ARG_DST); |
294 | auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST); |
295 | auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC); |
296 | |
297 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
298 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
299 | const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); |
300 | |
301 | const auto ou_stride = pd()->outer_stride(); |
302 | |
303 | parallel_nd(outer_size_, [&](dim_t ou) { |
304 | float sbr = 0; |
305 | size_t off = ou * ou_stride; |
306 | if (pd()->is_softmax()) { |
307 | for (size_t loff = off; loff < off + channels_; ++loff) { |
308 | float d = io::load_float_value(dst_d.data_type(), dst, loff); |
309 | float dd = io::load_float_value( |
310 | diff_dst_d.data_type(), diff_dst, loff); |
311 | sbr += dd * d; |
312 | } |
313 | for (size_t loff = off; loff < off + channels_; ++loff) { |
314 | float d = io::load_float_value(dst_d.data_type(), dst, loff); |
315 | float dd = io::load_float_value( |
316 | diff_dst_d.data_type(), diff_dst, loff); |
317 | float val = d * (dd - sbr); |
318 | io::store_float_value( |
319 | diff_src_d.data_type(), val, diff_src, loff); |
320 | } |
321 | } else if (pd()->is_logsoftmax()) { |
322 | for (size_t loff = off; loff < off + channels_; ++loff) { |
323 | float dd = io::load_float_value( |
324 | diff_dst_d.data_type(), diff_dst, loff); |
325 | sbr += dd; |
326 | } |
327 | for (size_t loff = off; loff < off + channels_; ++loff) { |
328 | float d = io::load_float_value(dst_d.data_type(), dst, loff); |
329 | float dd = io::load_float_value( |
330 | diff_dst_d.data_type(), diff_dst, loff); |
331 | float val = dd - expf(d) * sbr; |
332 | io::store_float_value( |
333 | diff_src_d.data_type(), val, diff_src, loff); |
334 | } |
335 | } |
336 | }); |
337 | return status::success; |
338 | } |
339 | |
340 | status_t ref_softmax_bwd_t::execute_backward_generic( |
341 | const exec_ctx_t &ctx) const { |
342 | auto dst = CTX_IN_MEM(const void *, DNNL_ARG_DST); |
343 | auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST); |
344 | auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC); |
345 | |
346 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
347 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
348 | const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); |
349 | |
350 | const auto is_inplace = (diff_dst == diff_src); |
351 | const auto has_padding = is_padding(diff_dst_d); |
352 | if (has_padding && !is_inplace) { |
353 | if (diff_dst_d.is_dense(true)) { |
354 | const auto res |
355 | = std::div(static_cast<int>(diff_dst_d.size()), PAGE_4K); |
356 | if (!res.quot) |
357 | std::memset(diff_src, 0, res.rem); |
358 | else |
359 | parallel_nd(res.quot, [&](dim_t i) { |
360 | const auto tail = (i + 1 == res.quot) ? res.rem : 0; |
361 | const auto ptr_dst |
362 | = reinterpret_cast<unsigned char *>(diff_src) |
363 | + i * PAGE_4K; |
364 | std::memset(ptr_dst, 0, PAGE_4K + tail); |
365 | }); |
366 | } else |
367 | // needed for submemory correctness |
368 | ctx.zero_pad_output(DNNL_ARG_DIFF_SRC); |
369 | } |
370 | |
371 | parallel_nd(outer_size_, inner_size_, [&](dim_t ou, dim_t in) { |
372 | dim_t ou_in_offset = ou * channels_ * inner_size_ + in; |
373 | float sbr = 0; |
374 | for (int c = 0; c < channels_; ++c) { |
375 | auto diff_dst_off |
376 | = diff_dst_d.off_l(ou_in_offset + c * inner_size_); |
377 | float dd = io::load_float_value( |
378 | diff_dst_d.data_type(), diff_dst, diff_dst_off); |
379 | if (pd()->is_softmax()) { |
380 | auto dst_off = dst_d.off_l(ou_in_offset + c * inner_size_); |
381 | float d = io::load_float_value(dst_d.data_type(), dst, dst_off); |
382 | sbr += dd * d; |
383 | } else if (pd()->is_logsoftmax()) { |
384 | sbr += dd; |
385 | } |
386 | } |
387 | |
388 | for (int c = 0; c < channels_; ++c) { |
389 | auto diff_dst_off |
390 | = diff_dst_d.off_l(ou_in_offset + c * inner_size_); |
391 | auto dst_off = dst_d.off_l(ou_in_offset + c * inner_size_); |
392 | float d = io::load_float_value(dst_d.data_type(), dst, dst_off); |
393 | float dd = io::load_float_value( |
394 | diff_dst_d.data_type(), diff_dst, diff_dst_off); |
395 | float val = 0; |
396 | if (pd()->is_softmax()) { |
397 | val = d * (dd - sbr); |
398 | } else if (pd()->is_logsoftmax()) { |
399 | val = dd - expf(d) * sbr; |
400 | } |
401 | auto diff_src_off |
402 | = diff_src_d.off_l(ou_in_offset + c * inner_size_); |
403 | io::store_float_value( |
404 | diff_src_d.data_type(), val, diff_src, diff_src_off); |
405 | } |
406 | }); |
407 | return status::success; |
408 | } |
409 | |
410 | } // namespace cpu |
411 | } // namespace impl |
412 | } // namespace dnnl |
413 | |
414 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
415 | |