/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <vector>

#include "common/broadcast_strategy.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/ref_io_helper.hpp"
#include "cpu/ref_prelu.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace math;
using namespace data_type;

static constexpr int max_supported_ndims = 5;

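// Maps a logical coordinate tuple onto the physical offset of the element
// described by the memory descriptor. Only 1- to 5-dimensional tensors are
// supported by this reference implementation.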
static dim_t offset(const memory_desc_wrapper &mem, dims_t dims) {
    const int ndims = mem.ndims();
    switch (ndims) {
        case 1: return mem.off(dims[0]);
        case 2: return mem.off(dims[0], dims[1]);
        case 3: return mem.off(dims[0], dims[1], dims[2]);
        case 4: return mem.off(dims[0], dims[1], dims[2], dims[3]);
        case 5: return mem.off(dims[0], dims[1], dims[2], dims[3], dims[4]);
        default: assert(!"Unsupported ndims count");
    }
    return -1;
}

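// Computes the offset into the (possibly broadcast) weights tensor that
// corresponds to a data coordinate: coordinates along broadcast dimensions
// are zeroed out by the mask before the offset is taken. For per-channel
// weights, for example, only the channel coordinate survives.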
static dim_t weights_offset(
        const int mask, const memory_desc_wrapper &mem, dims_t &dims) {
    dims_t dims_w {};
    std::copy(dims, dims + max_supported_ndims, dims_w);
    utils::apply_mask_on_dims(dims_w, mem.ndims(), mask);
    return offset(mem, dims_w);
}

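// Returns true when any logical dimension is smaller than its padded
// counterpart, i.e. the memory has physical padding that may need zeroing.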
static bool is_padding(const memory_desc_wrapper &md) {
    for (int i = 0; i < md.ndims(); i++)
        if (md.dims()[i] != md.padded_dims()[i]) return true;
    return false;
}

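// Forward PReLU: for every element, dst = src > 0 ? src : src * w, where the
// weight w is looked up through the broadcast mask (a single scalar, one
// value per channel, etc., depending on the weights shape).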
status_t ref_prelu_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
    if (pd()->has_zero_dim_memory()) return status::success;

    const auto src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const byte *, DNNL_ARG_WEIGHTS);
    auto dst = CTX_OUT_MEM(byte *, DNNL_ARG_DST);

    const memory_desc_wrapper data_d(pd()->src_md(0));
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    // If dst carries physical padding and is not shared with src (which
    // would already hold zeroed padding), clear the padded area up front.
    const auto is_inplace = (src == dst);
    const auto has_padding = is_padding(data_d);
    if (has_padding && !is_inplace) ctx.zero_pad_output(DNNL_ARG_DST);

    // Bit i of the mask is set iff the i-th weights dim matches the data
    // dim, i.e. the weights are not broadcast along that axis.
    const int mask = utils::get_dims_mask(
            data_d.dims(), weights_d.dims(), data_d.ndims());
    const dim_t work_amount = data_d.nelems();

    parallel(0, [&](std::size_t ithr, std::size_t nthr) {
        if ((dim_t)ithr >= work_amount) return;

        dim_t start {0}, end {0};
        dims_t dims_d, off;
        for (int i = 0; i < max_supported_ndims; i++) {
            off[i] = 0;
            dims_d[i] = (data_d.dims()[i] != 0) ? data_d.dims()[i] : 1;
        }

        balance211(work_amount, nthr, ithr, start, end);
        utils::nd_iterator_init(start, off[0], dims_d[0], off[1], dims_d[1],
                off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);

        for (dim_t iwork = start; iwork < end; ++iwork) {
            const auto data_off = offset(data_d, off);
            const auto weight_off = weights_offset(mask, weights_d, off);
            const float src_val
                    = io::load_float_value(data_d.data_type(), src, data_off);
            const float weights_val = io::load_float_value(
                    weights_d.data_type(), weights, weight_off);

            const float res = relu_fwd(src_val, weights_val);

            io::store_float_value(data_d.data_type(), res, dst, data_off);
            utils::nd_iterator_step(off[0], dims_d[0], off[1], dims_d[1],
                    off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);
        }
    });
    return status::success;
}

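// In-place pairwise (tree) reduction: each pass sums adjacent pairs and
// halves the element count; an odd leftover element is folded into mem[0].
// E.g. for {a, b, c, d, e}, pass 1 leaves {a+b+c, d+e} and pass 2 leaves
// {a+b+c+d+e}. Pairwise summation keeps the float rounding error lower than
// a naive running sum over the same data.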
static float reduce(float *mem, dim_t size) {
    bool tail = size % 2;
    const auto reduce_iteration = [&](float *mem) {
        const auto div_res = std::div(size, (dim_t)2);
        tail = div_res.rem;
        size = div_res.quot;
        if (!tail && !size) {
            mem[0] = 0;
            return;
        }
        dim_t i {0}, off {0};
        if (tail) {
            if (size) mem[0] += mem[1 + off] + mem[2 + off];
            ++off;
            ++i;
        }
        for (; i < size; i++) {
            mem[i] = mem[2 * i + off] + mem[(2 * i + 1) + off];
        }
    };
    while (size > 1) {
        reduce_iteration(mem);
    }
    return mem[0];
}

namespace prelu {
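// Splits a per-thread workload into ~sqrt(work_amount) groups of
// ~sqrt(work_amount) elements each, so partial sums are accumulated in two
// balanced levels. E.g. work_amount = 10 gives buf_size = 3 and
// group_size = 4 (4 * 3 = 12 >= 10); the last group simply ends up shorter.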
void set_reduction_buffers(
        const dim_t work_amount, dim_t &group_size, dim_t &buf_size) {
    const float sqrt_wa = std::sqrt(work_amount);
    group_size = std::ceil(sqrt_wa);
    buf_size = std::floor(sqrt_wa);
    if (group_size * buf_size < work_amount) group_size++;
}

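// Replays the balance211 partitioning for all threads preceding ithr to find
// where this thread's slice of the reduction scratchpad starts: every thread
// owns buf_size + group_size floats (element buffer plus group buffer).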
dim_t get_scalar_scratchpad_offset(const std::size_t ithr,
        const std::size_t nthr, const dim_t work_amount) {
    dim_t offset {0}, group_size, buf_size;
    for (std::size_t i = 0; i < ithr; i++) {
        dim_t start {0}, end {0};
        balance211(work_amount, nthr, i, start, end);
        const dim_t workload = end - start;
        set_reduction_buffers(workload, group_size, buf_size);
        offset += buf_size;
        offset += group_size;
    }
    return offset;
}
} // namespace prelu

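// Backward kernel for a single element: writes diff_src in place and returns
// this element's contribution to diff_weights:
//   diff_src               = src > 0 ? diff_dst : diff_dst * w
//   diff_weight (returned) = src > 0 ? 0        : diff_dst * src
// relu_bwd_use_dst is invoked with src as its condition operand, which
// computes exactly the diff_src formula above. Note that src, diff_dst, and
// diff_src are all accessed with the src data type here.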
float ref_prelu_bwd_t::ker(const byte *src, const byte *weights,
        const byte *diff_dst, byte *diff_src, dim_t data_off,
        dim_t weight_off) const {

    const auto dtype = pd()->src_md(0)->data_type;
    const auto wtype = pd()->weights_md(0)->data_type;
    const float src_val = io::load_float_value(dtype, src, data_off);
    const float diff_dst_val = io::load_float_value(dtype, diff_dst, data_off);
    const float weights_val = io::load_float_value(wtype, weights, weight_off);

    const float diff_src_res
            = relu_bwd_use_dst(diff_dst_val, src_val, weights_val);
    const float diff_weight_res = src_val > 0 ? 0 : (diff_dst_val * src_val);

    io::store_float_value(dtype, diff_src_res, diff_src, data_off);

    return diff_weight_res;
}

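// Broadcast strategy 'scalar': a single weight is shared by every element,
// so diff_weights is the reduction of the per-element contributions over the
// whole tensor. Each thread accumulates its slice through the two-level
// (buf/group) scratchpad buffers; the per-thread partial sums are reduced at
// the end.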
void ref_prelu_bwd_t::calculate_scalar(const byte *src, const byte *weights,
        byte *diff_weights, const byte *diff_dst, byte *diff_src,
        float *scratchpad_buf) const {

    const memory_desc_wrapper data_d(pd()->src_md(0));
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const int nthr = pd()->nthr_;
    const dim_t work_amount = data_d.nelems();
    const int thread_count = nstl::min((dim_t)nthr, work_amount);

    std::vector<float> buf_nthr_partial_results(nthr);

    parallel(nthr, [&](std::size_t ithr, std::size_t nthr) {
        if ((dim_t)ithr >= work_amount) return;

        dim_t start {0}, end {0};
        dims_t dims_d, off;
        for (int i = 0; i < max_supported_ndims; i++) {
            off[i] = 0;
            dims_d[i] = (data_d.dims()[i] != 0) ? data_d.dims()[i] : 1;
        }

        balance211(work_amount, nthr, ithr, start, end);
        const dim_t workload = end - start;

        utils::nd_iterator_init(start, off[0], dims_d[0], off[1], dims_d[1],
                off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);

        dim_t group_size, buf_size;
        prelu::set_reduction_buffers(workload, group_size, buf_size);

        // Carve this thread's element buffer and group buffer out of the
        // shared scratchpad.
        const dim_t scratchpad_offset
                = prelu::get_scalar_scratchpad_offset(ithr, nthr, work_amount);
        auto *buf = &scratchpad_buf[scratchpad_offset];
        auto *group_buf = &scratchpad_buf[scratchpad_offset + buf_size];

        dim_t offset_buf {0}, group_off {0}, data_size {buf_size};
        for (dim_t iwork = start; iwork < end; ++iwork) {
            const auto data_off = offset(data_d, off);
            const auto weight_off = 0; // scalar weight: always element 0
            buf[offset_buf] = ker(
                    src, weights, diff_dst, diff_src, data_off, weight_off);
            // Once a group's buffer is full, collapse it into one partial
            // sum; the last group may hold fewer than buf_size elements.
            if (++offset_buf == data_size) {
                group_buf[group_off++] = reduce(buf, offset_buf);
                offset_buf = 0;
                data_size = ((group_off + 1) * buf_size <= workload)
                        ? buf_size
                        : workload - (group_off * buf_size);
            }
            utils::nd_iterator_step(off[0], dims_d[0], off[1], dims_d[1],
                    off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);
        }
        buf_nthr_partial_results[ithr] = reduce(group_buf, group_size);
    });
    io::store_float_value(weights_d.data_type(),
            reduce(&buf_nthr_partial_results[0], thread_count), diff_weights,
            0);
}

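// Broadcast strategy 'no_broadcast': the weights tensor has the same shape
// as the data, so every element owns its weight and the per-element diff
// weight can be stored directly, with no reduction.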
void ref_prelu_bwd_t::calculate_no_broadcast(const byte *src,
        const byte *weights, byte *diff_weights, const byte *diff_dst,
        byte *diff_src, float *scratchpad_buf) const {

    const memory_desc_wrapper data_d(pd()->src_md(0));
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const int nthr = pd()->nthr_;
    const dim_t work_amount = data_d.nelems();
    const int mask = utils::get_dims_mask(
            data_d.dims(), weights_d.dims(), data_d.ndims());

    parallel(nthr, [&](std::size_t ithr, std::size_t nthr) {
        if ((dim_t)ithr >= work_amount) return;

        dim_t start {0}, end {0};
        dims_t dims_d, off;
        for (int i = 0; i < max_supported_ndims; i++) {
            off[i] = 0;
            dims_d[i] = (data_d.dims()[i] != 0) ? data_d.dims()[i] : 1;
        }

        balance211(work_amount, nthr, ithr, start, end);
        utils::nd_iterator_init(start, off[0], dims_d[0], off[1], dims_d[1],
                off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);

        for (dim_t iwork = start; iwork < end; ++iwork) {
            const auto data_off = offset(data_d, off);
            const auto weight_off = weights_offset(mask, weights_d, off);
            const auto res = ker(
                    src, weights, diff_dst, diff_src, data_off, weight_off);

            io::store_float_value(
                    weights_d.data_type(), res, diff_weights, weight_off);
            utils::nd_iterator_step(off[0], dims_d[0], off[1], dims_d[1],
                    off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);
        }
    });
}

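// Remaining broadcast strategies (per_oc, shared_axes, ...): threads are
// spread over the weights elements; each weight reduces the diff
// contributions of the data slice it is shared with, reusing the same
// two-level reduction as the scalar case.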
void ref_prelu_bwd_t::calculate_shared_axes(const byte *src,
        const byte *weights, byte *diff_weights, const byte *diff_dst,
        byte *diff_src, float *scratchpad_buf) const {

    const memory_desc_wrapper data_d(pd()->src_md(0));
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    dims_t dims_d, dims_w;
    for (int i = 0; i < max_supported_ndims; i++) {
        dims_d[i] = (data_d.dims()[i] != 0) ? data_d.dims()[i] : 1;
        dims_w[i] = (weights_d.dims()[i] != 0) ? weights_d.dims()[i] : 1;
    }

    const int nthr = pd()->nthr_;
    const dim_t work_amount = weights_d.nelems();

    parallel(nthr, [&](std::size_t ithr, std::size_t nthr) {
        if ((dim_t)ithr >= work_amount) return;

        dim_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        // Every weight element reduces the same number of data elements.
        dim_t group_size, buf_size;
        const dim_t workload = data_d.nelems() / weights_d.nelems();
        prelu::set_reduction_buffers(workload, group_size, buf_size);
        dim_t scratchpad_offset = (buf_size + group_size) * ithr;
        auto *buf = &scratchpad_buf[scratchpad_offset];
        auto *group_buf = &scratchpad_buf[scratchpad_offset + buf_size];

        dims_t off_w, off_d, dims_start, dims_end;
        utils::nd_iterator_init(start, off_w[0], dims_w[0], off_w[1], dims_w[1],
                off_w[2], dims_w[2], off_w[3], dims_w[3], off_w[4], dims_w[4]);

        for (dim_t iwork = start; iwork < end; ++iwork) {
            const auto weight_off = offset(weights_d, off_w);
            // Broadcast axes span their whole data dim; matching axes are
            // pinned to the current weight coordinate.
            for (int i = 0; i < max_supported_ndims; i++) {
                dims_start[i] = (dims_d[i] == dims_w[i]) ? off_w[i] : 0;
                dims_end[i]
                        = (dims_d[i] == dims_w[i]) ? off_w[i] + 1 : dims_d[i];
            }
            dim_t buf_off {0}, group_off {0}, data_size {buf_size};
            for_(off_d[0] = dims_start[0]; off_d[0] < dims_end[0]; ++off_d[0])
            for_(off_d[1] = dims_start[1]; off_d[1] < dims_end[1]; ++off_d[1])
            for_(off_d[2] = dims_start[2]; off_d[2] < dims_end[2]; ++off_d[2])
            for_(off_d[3] = dims_start[3]; off_d[3] < dims_end[3]; ++off_d[3])
            for (off_d[4] = dims_start[4]; off_d[4] < dims_end[4]; ++off_d[4]) {
                const auto data_off = offset(data_d, off_d);
                const auto diff_weight = ker(
                        src, weights, diff_dst, diff_src, data_off, weight_off);
                buf[buf_off] = diff_weight;
                if (++buf_off == data_size) {
                    group_buf[group_off++] = reduce(buf, buf_off);
                    buf_off = 0;
                    data_size = ((group_off + 1) * buf_size <= workload)
                            ? buf_size
                            : workload - (group_off * buf_size);
                }
            }
            io::store_float_value(weights_d.data_type(),
                    reduce(group_buf, group_size), diff_weights, weight_off);
            utils::nd_iterator_step(off_w[0], dims_w[0], off_w[1], dims_w[1],
                    off_w[2], dims_w[2], off_w[3], dims_w[3], off_w[4],
                    dims_w[4]);
        }
    });
}

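// Backward entry point: zero-pads outputs where needed, classifies the
// weights shape into a broadcast strategy, and dispatches to the matching
// diff_weights computation. The float reduction workspace comes from the
// key_prelu_reduction scratchpad entry.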
status_t ref_prelu_bwd_t::execute_backward(const exec_ctx_t &ctx) const {

    if (pd()->has_zero_dim_memory()) return status::success;

    const auto scratchpad = ctx.get_scratchpad_grantor();
    auto scratchpad_buf = scratchpad.template get<float>(
            memory_tracking::names::key_prelu_reduction);

    const auto src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const byte *, DNNL_ARG_WEIGHTS);
    auto diff_weights = CTX_OUT_MEM(byte *, DNNL_ARG_DIFF_WEIGHTS);
    const auto diff_dst = CTX_IN_MEM(const byte *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_MEM(byte *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper data_d(pd()->src_md(0));
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md(0));
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const auto bcast_type = dnnl::impl::get_rhs_arg_broadcasting_strategy(
            *weights_d.md_, data_d);

    const auto is_inplace = (diff_src == diff_dst);
    if (is_padding(diff_src_d) && !is_inplace)
        ctx.zero_pad_output(DNNL_ARG_DIFF_SRC);

    if (is_padding(diff_weights_d)) ctx.zero_pad_output(DNNL_ARG_DIFF_WEIGHTS);

    switch (bcast_type) {
        case broadcasting_strategy_t::scalar:
            calculate_scalar(src, weights, diff_weights, diff_dst, diff_src,
                    scratchpad_buf);
            break;
        case broadcasting_strategy_t::no_broadcast:
            calculate_no_broadcast(src, weights, diff_weights, diff_dst,
                    diff_src, scratchpad_buf);
            break;
        case broadcasting_strategy_t::per_oc:
        case broadcasting_strategy_t::per_oc_spatial:
        case broadcasting_strategy_t::per_mb_spatial:
        case broadcasting_strategy_t::per_mb_w:
        case broadcasting_strategy_t::per_w:
        case broadcasting_strategy_t::shared_axes:
            calculate_shared_axes(src, weights, diff_weights, diff_dst,
                    diff_src, scratchpad_buf);
            break;
        default: assert(!"unsupported broadcast type");
    }
    return status::success;
}

} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s